In [1]:
import torch
from torch import nn
from torch.utils import data
from torch.cuda.amp import autocast, GradScaler

import numpy as np
from tqdm.notebook import tqdm

import sys
import datetime

sys.path.append('../code')
from dataset import get_data, make_vocab, WantWordsDataset as WWDataset

from transformers import (
    AdamW, get_linear_schedule_with_warmup
)

from models import SentenceBERTForRD

In [2]:
d = get_data('../wantwords-english-baseline/data', word2vec=True)

Loading data...
word2vec: 75099 vectors
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [3]:
train_data, train_data_def, dev_data, test_data_seen, \
    test_data_unseen, test_data_desc, word2vec = d

In [4]:
target2idx, idx2target = make_vocab(d[:-1], None)

In [5]:
# target2idx maps target words to indices
# target_matrix maps target indices to bpe sequences, padded/truncated to mask_size
target2idx['book'], idx2target[target2idx['book']]

(16187, 'book')

In [6]:
# can freeze for (part of) first epoch or so and then unfreeze to train the whole model
model = SentenceBERTForRD('distilbert-base-nli-stsb-mean-tokens', 
                          300, freeze_sbert=False, criterion=nn.MSELoss())

In [7]:
T = model.sbert.tokenizer
train_dataset = WWDataset(train_data + train_data_def, T, target2idx, word2vec)
dev_dataset = WWDataset(dev_data, T, target2idx, word2vec)
test_dataset_seen = WWDataset(test_data_seen, T, target2idx, word2vec)
test_dataset_unseen = WWDataset(test_data_unseen, T, target2idx, word2vec)
test_dataset_desc = WWDataset(test_data_desc, T, target2idx, word2vec)

In [8]:
batch_size = 55
num_workers = 0

loader_params = {
    'pin_memory': False,
    'batch_size': batch_size,
    'num_workers': num_workers,
    'collate_fn': train_dataset.collate_fn
}

train_loader = data.DataLoader(train_dataset, **{'shuffle': True, **loader_params})
dev_loader = data.DataLoader(dev_dataset, **{'shuffle': True, **loader_params})
test_loader_seen = data.DataLoader(test_dataset_seen, **{'shuffle': False, **loader_params})
test_loader_unseen = data.DataLoader(test_dataset_unseen, **{'shuffle': False, **loader_params})
test_loader_desc = data.DataLoader(test_dataset_desc, **{'shuffle': False, **loader_params})

In [9]:
epochs = 10

lr = 2e-5
optim = AdamW(model.parameters(), lr=lr)

warmup_duration = 0.01 # portion of the first epoch spent on lr warmup
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=len(train_loader) * warmup_duration, 
                                            num_training_steps=len(train_loader) * epochs)

epoch = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# scaler = GradScaler()

In [10]:
import wandb

wandb.init(project='reverse-dictionary', entity='reverse-dict')

config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
config.optimizer = type(optim).__name__
config.scheduler = type(scheduler).__name__
config.warmup_duration = warmup_duration

wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.29 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[<wandb.wandb_torch.TorchGraph at 0x7f752726c810>]

In [11]:
model = model.to(device)
word2vec.embeddings = word2vec.embeddings.to(device)

In [12]:
def evaluate(pred, gt, test=False):
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        if test:
            loc = (p == word).nonzero(as_tuple=True)
            if len(loc) != 0:
                pred_rank.append(min(loc[-1], 1000))
            else:
                pred_rank.append(1000)
        if word in p[:100]:
            acc100 += 1
            if word in p[:10]:
                acc10 += 1
                if word == p[0]:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1, acc10, acc100, pred_rank)
    else:
        return acc1/n, acc10/n, acc100/n

In [13]:
def test(loader, name):
    inc = 3
    model.eval()
    test_loss = 0.0
    test_acc1 = test_acc10 = test_acc100 = 0.0
    total_seen = 0
    all_pred = []
    with torch.no_grad():
        with tqdm(total=len(loader)) as pbar:
            for i, ((x, attention_mask), y, yvecs) in enumerate(loader):
                if i % inc == 0 and i != 0:
                    display_loss = test_loss / i
                    pbar.set_description(f'Test Loss: {display_loss}')

                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = y.to(device)
                yvecs = yvecs.to(device)

#                 with autocast():
                loss, out = model(input_ids=x, attention_mask=attention_mask,
                                  ground_truth=yvecs)

                test_loss += loss.detach()
            
                out = out @ word2vec.embeddings.T

                result, indices = torch.sort(out, descending=True)
                
                b = len(x)
                acc1, acc10, acc100, pred_rank = evaluate(indices, y, test=True)
                test_acc1 += acc1
                test_acc10 += acc10
                test_acc100 += acc100
                total_seen += b
                all_pred.extend(pred_rank)
                
                del x, y, out, loss
                if i % 20 == 0:
                    torch.cuda.empty_cache()
                    
                pbar.update(1)
    
    test_loss /= len(loader)
    test_acc1 /= total_seen
    test_acc10 /= total_seen
    test_acc100 /= total_seen
    all_pred = torch.tensor(all_pred)
    median = torch.median(all_pred)
    var = torch.var(all_pred)**0.5
    
    print(f'{name}_test_loss:', test_loss)
    print(f'{name}_test_acc1:', test_acc1)
    print(f'{name}_test_acc10:', test_acc10)
    print(f'{name}_test_acc100:', test_acc100)
    print(f'{name}_test_rank_median:', median)
    print(f'{name}_test_rank_variance', var)
    
    return ({
        f'{name}_test_loss': test_loss,
        f'{name}_test_acc1': test_acc1,
        f'{name}_test_acc10': test_acc10,
        f'{name}_test_acc100': test_acc100,
        f'{name}_test_rank_median': median,
        f'{name}_test_rank_variance': var
    })
    

In [14]:
inc = 10
losses = []

for epoch in range(epoch, epochs):
    # Training
    model.train()
    train_loss = 0.0
    # Train on subset of training data to save time
    with tqdm(total=len(train_loader)) as pbar:
        for i, ((x, attention_mask), _, yvecs) in enumerate(train_loader):
            if i % inc == 0 and i != 0:
                display_loss = train_loss / i
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {train_loss / i}')

            optim.zero_grad()

            x = x.to(device)
            attention_mask = attention_mask.to(device)
            yvecs = yvecs.to(device)
            
            loss, out = model(input_ids=x, attention_mask=attention_mask, 
                              ground_truth=yvecs)

#             scaler.scale(loss).backward()
            loss.backward()
            
#             scaler.unscale_(optim)
            nn.utils.clip_grad_value_(model.parameters(), 5)
            
#             scaler.step(optim)
            optim.step()
#             scaler.update()
            
            train_loss += loss.detach()
            
            scheduler.step()
            
            pbar.update(1)
            
            del x, yvecs, out, loss, attention_mask
            
    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb+') as f:
        torch.save({'state_dict': model.state_dict()}, f)
    
    # Validation
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    with torch.no_grad():
        with tqdm(total=len(dev_loader)) as pbar:
            for i, ((x, attention_mask), y, yvecs) in enumerate(dev_loader):
                if i % inc == 0 and i != 0:
                    display_loss = val_loss / i
                    pbar.set_description(f'Epoch {epoch+1}, Val Loss: {val_loss / i}')

                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = y.to(device)
                yvecs = yvecs.to(device)

#                 with autocast():
                loss, out = model(input_ids=x, attention_mask=attention_mask,
                                  ground_truth=yvecs)
    
                out = out @ word2vec.embeddings.T

                val_loss += loss.detach()                
                
                result, indices = torch.topk(out, k=100, dim=-1, largest=True, sorted=True)
                
                acc1, acc10, acc100 = evaluate(indices, y)
                val_acc1 += acc1
                val_acc10 += acc10
                val_acc100 += acc100

                del x, y, yvecs, out, loss
                
                pbar.update(1)
    
    wandb.log({
        'train_loss': train_loss / len(train_loader),
        'val_loss': val_loss / len(dev_loader),
        'val_acc1': val_acc1 / len(dev_loader),
        'val_acc10': val_acc10 / len(dev_loader),
        'val_acc100': val_acc100 / len(dev_loader),
        **test(test_loader_seen, 'seen'),
        **test(test_loader_unseen, 'unseen'),
        **test(test_loader_desc, 'desc')
    })
    

  0%|          | 0/12286 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
def getPredFromDesc(model, desc : str, top_n=10):
    model = model.eval()
    desc = T(desc, return_tensors='pt', padding=True)
    x = desc['input_ids'].to(device)
    attention_mask = desc['attention_mask'].to(device)
    out = model(input_ids=x, attention_mask=attention_mask)
    out = out @ word2vec.embeddings.T
    result, indices = torch.topk(out, k=top_n, dim=-1, largest=True, sorted=True)
    
    indices = indices[0]
    return [word2vec.itos[i] for i in indices], indices
    

In [18]:
test(test_loader_seen, 'seen')

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(0.0313, device='cuda:0')
seen_test_acc1: 0.0
seen_test_acc10: 0.0
seen_test_acc100: 0.004
seen_test_rank_median: tensor(1000.)
seen_test_rank_variance tensor(100.8192)


{'seen_test_loss': tensor(0.0313, device='cuda:0'),
 'seen_test_acc1': 0.0,
 'seen_test_acc10': 0.0,
 'seen_test_acc100': 0.004,
 'seen_test_rank_median': tensor(1000.),
 'seen_test_rank_variance': tensor(100.8192)}

In [29]:
test(test_loader_unseen, 'unseen')

  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(14.7876, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.006
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(82.8673)


{'unseen_test_loss': tensor(14.7876, device='cuda:0'),
 'unseen_test_acc1': 0.004,
 'unseen_test_acc10': 0.006,
 'unseen_test_acc100': 0.006,
 'unseen_test_rank_median': tensor(1000.),
 'unseen_test_rank_variance': tensor(82.8673)}

In [19]:
test(test_loader_desc, 'desc')

  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(0.0275, device='cuda:0')
desc_test_acc1: 0.0
desc_test_acc10: 0.0
desc_test_acc100: 0.0
desc_test_rank_median: tensor(1000.)
desc_test_rank_variance tensor(0.)


{'desc_test_loss': tensor(0.0275, device='cuda:0'),
 'desc_test_acc1': 0.0,
 'desc_test_acc10': 0.0,
 'desc_test_acc100': 0.0,
 'desc_test_rank_median': tensor(1000.),
 'desc_test_rank_variance': tensor(0.)}

In [21]:
getPredFromDesc(model, 'an inhabitant of a cold country', 100)

(['permafrost',
  'tundra',
  'snowpack',
  'glaciers',
  'pikas',
  'snowing',
  'microclimate',
  'timbered',
  'snows',
  'steppe',
  'unspoilt',
  'alps',
  'massif',
  'olp',
  'steppes',
  'showery',
  'arctic',
  'snowfalls',
  'wintry',
  'lynx',
  'subzero',
  'treeless',
  'snowfall',
  'outdoorsy',
  'crevasses',
  'morels',
  'moors',
  'snowfield',
  'winters',
  'glacier',
  'snowy',
  'snow',
  'glaciation',
  'obigation',
  'marmot',
  'alpine',
  'overwinter',
  'taiga',
  'igloo',
  'frost',
  'lichen',
  'pika',
  'ptarmigan',
  'frostbitten',
  'snowshoes',
  'frosts',
  'fjord',
  'meltwater',
  'anticyclone',
  'primroses',
  'crags',
  'tannic',
  'crag',
  'glades',
  'mossy',
  'birches',
  'subarctic',
  'rustic',
  'moose',
  'hoppy',
  'icy',
  'subfreezing',
  'moister',
  'aspen',
  'precipitation',
  'coldest',
  'yurt',
  'lichens',
  'yeti',
  'snowcapped',
  'snowshoe',
  'colder',
  'terroir',
  'parka',
  'bluebells',
  'sleet',
  'overwintering',
  

In [22]:
getPredFromDesc(model, 'employee at a circus', 100)

(['panto',
  'vaudeville',
  'ceilidh',
  'pantomime',
  'tearoom',
  'governess',
  'busker',
  'funfair',
  'projectionist',
  'cabaret',
  'speakeasy',
  'vaudevillian',
  'burlesque',
  'taxidermist',
  'stagehand',
  'nobleman',
  'minstrel',
  'ventriloquist',
  'toff',
  'bullfight',
  'taxidermy',
  'showgirl',
  'costuming',
  'puppeteer',
  'boardinghouse',
  'proscenium',
  'aerialist',
  'marionette',
  'taverna',
  'magician',
  'geisha',
  'footman',
  'milliner',
  'carny',
  'garret',
  'madwoman',
  'circus',
  'stagehands',
  'cowhand',
  'seance',
  'manservant',
  'noblewoman',
  'kabuki',
  'blacksmith',
  'circuses',
  'bordello',
  'pantomimes',
  'fayre',
  'trattoria',
  'juggler',
  'showman',
  'ventriloquism',
  'séance',
  'bullfights',
  'nunnery',
  'mime',
  'schoolmaster',
  'wench',
  'huntsman',
  'troupe',
  'dowager',
  'trapeze',
  'scherzo',
  'hostelry',
  'haberdashery',
  'coloratura',
  'marionettes',
  'nativity',
  'roadhouse',
  'boozer',
 

In [23]:
getPredFromDesc(model, 'a road on which cars can go fast', 100)

(['chicane',
  'motorcade',
  'understeer',
  'downforce',
  'oversteer',
  'cortege',
  'understeering',
  'laps',
  'speedometer',
  'swerved',
  'hearse',
  'autocross',
  'eastbound',
  'suplex',
  'southbound',
  'stoplight',
  'braking',
  'mph',
  'swerving',
  'blinker',
  'fishtailed',
  'esses',
  'northbound',
  'jeep',
  'swerve',
  'hairpin',
  'wheelbase',
  'freeway',
  'dragster',
  'carriageway',
  'motorcyclist',
  'lapper',
  'fishtail',
  'fastball',
  'backstretch',
  'lanes',
  'speeder',
  'guardrail',
  'driverless',
  'windscreen',
  'rockpile',
  'headlight',
  'cruiser',
  'convoy',
  'lane',
  'superspeedway',
  'sunroof',
  'sedan',
  'layup',
  'headlights',
  'expressway',
  'dragstrip',
  'camber',
  'jackknifed',
  'criterium',
  'kart',
  'honked',
  'carjack',
  'grounder',
  'contraflow',
  'streamliner',
  'straightaway',
  'motorway',
  'gearshift',
  'powerslam',
  'roadway',
  'taxied',
  'jaywalk',
  'oversteering',
  'rickshaw',
  'taxiway',
  

In [24]:
getPredFromDesc(model, 'something you use to measure your temperature', 100)

(['millibars',
  'refrigerate',
  'temperature',
  'thermometer',
  'preheat',
  'thermostat',
  'temperatures',
  'thermometers',
  'oven',
  'thermocouple',
  'convection',
  'troposphere',
  'dehumidifier',
  'humidifier',
  'altimeter',
  'defrost',
  'psi',
  'hygrometer',
  'saucepan',
  'conductivity',
  'magma',
  'snowpack',
  'deg',
  'thermocouples',
  'ionization',
  'precipitation',
  'humidity',
  'evaporator',
  'albedo',
  'voltages',
  'caramelize',
  'heater',
  'moisture',
  'ozone',
  'supercooled',
  'pressurization',
  'defroster',
  'coldest',
  'skillet',
  'anemometer',
  'meltwater',
  'tsp',
  'spectrometer',
  'reflectance',
  'refrigerating',
  'viscosity',
  'radon',
  'hydrometer',
  'inductance',
  'sunspot',
  'climatologist',
  'flowmeter',
  'coolant',
  'argon',
  'watts',
  'airflow',
  'qubit',
  'workpiece',
  'warmest',
  'refrigerant',
  'microclimate',
  'conductance',
  'condenser',
  'airspeed',
  'evapotranspiration',
  'reflectivity',
  'sn

In [32]:
getPredFromDesc(model, 'a large house that a rich person lives in', 100)

(['mansion',
  'house',
  'hall',
  'villa',
  'palace',
  'stateroom',
  'apartment',
  'hotel',
  'homestead',
  'domiciled',
  'tenement',
  'lodge',
  'roof',
  'lofting',
  'casa',
  'housing',
  'saloon',
  'cottage',
  'penthouse',
  'residence',
  'roofs',
  'cabinet',
  'householder',
  'home',
  'inn',
  'castle',
  'maisonette',
  'condominium',
  'manor',
  'seraglio',
  'summerhouse',
  'nesting',
  'placed',
  'ibn',
  'townhouse',
  'parkour',
  'cribbed',
  'demesne',
  'palatial',
  'landlord',
  'dwelling',
  'divan',
  'loge',
  'parlor',
  'court',
  'stage',
  'nested',
  'lodging',
  'cabin',
  'chateau',
  'building',
  'nest',
  'chamber',
  'cloakroom',
  'edifice',
  'housemate',
  'rooms',
  'stacked',
  'residency',
  'buttery',
  'guesthouse',
  'studio',
  'rotunda',
  'pension',
  'stack',
  'floor',
  'houseguest',
  'treasury',
  'kitty',
  'hovel',
  'lobbying',
  'pavilion',
  'household',
  'condo',
  'tenancy',
  'roundhouse',
  'stacks',
  'feature

In [33]:
getPredFromDesc(model, 'the opposite of being happy', 100)

(['unhappy',
  'happiness',
  'sadness',
  'misery',
  'vanity',
  'misfortune',
  'bashfulness',
  'complacence',
  'felicitate',
  'evil',
  'nothingness',
  'depression',
  'unhappiness',
  'absurd',
  'emptiness',
  'disappointment',
  'frailty',
  'pessimism',
  'felicity',
  'diffidence',
  'inconvenience',
  'melancholic',
  'egoism',
  'indisposition',
  'irrationality',
  'penury',
  'folly',
  'mischief',
  'languor',
  'blithe',
  'gravity',
  'hopeless',
  'unreality',
  'atheism',
  'joviality',
  'unreasoning',
  'luck',
  'adversity',
  'egotism',
  'discontent',
  'disconsolate',
  'nihilism',
  'insanity',
  'pitiable',
  'unkindly',
  'spleen',
  'transience',
  'arrogance',
  'magnanimity',
  'affliction',
  'gaiety',
  'demerit',
  'infatuation',
  'modesty',
  'ugliness',
  'uneasiness',
  'impotence',
  'disinterest',
  'disaffection',
  'malevolence',
  'sterility',
  'flatness',
  'envy',
  'deficiency',
  'goodness',
  'deadness',
  'regrets',
  'leniency',
  '

In [35]:
getPredFromDesc(model, 'a mammal that lives in water', 100) # decent results

(['amphibian',
  'amphibious',
  'swim',
  'porpoise',
  'plankton',
  'aqueous',
  'swimmer',
  'subaquatic',
  'cetacean',
  'gill',
  'puffer',
  'dipper',
  'swimming',
  'benthos',
  'sucker',
  'monkfish',
  'pelagic',
  'underwater',
  'hydrate',
  'smelt',
  'prawn',
  'submersible',
  'brine',
  'salamander',
  'ocean',
  'limnology',
  'beluga',
  'leviathan',
  'groundfish',
  'laver',
  'leeching',
  'catfish',
  'goby',
  'porgy',
  'neptune',
  'pinniped',
  'diver',
  'hydraulic',
  'lifeguard',
  'hydrosphere',
  'beaver',
  'skater',
  'drowned',
  'bathe',
  'sunfish',
  'dowse',
  'mermaid',
  'mussel',
  'tadpole',
  'dogfish',
  'awash',
  'sheepshead',
  'hardhead',
  'grouper',
  'planktonic',
  'anglerfish',
  'jellyfish',
  'bloodsucker',
  'afloat',
  'vessel',
  'paddling',
  'moray',
  'sea',
  'darter',
  'loach',
  'cataract',
  'shovelnose',
  'mullet',
  'swordfish',
  'laker',
  'butterfish',
  'nautilus',
  'steamer',
  'lake',
  'submarine',
  'supern

In [36]:
getPredFromDesc(model, 'a mammal that lives in ocean', 100) # bad results

(['oceanic',
  'ocean',
  'pelagic',
  'neptune',
  'sea',
  'benthos',
  'oceanography',
  'marine',
  'leviathan',
  'transoceanic',
  'seagoing',
  'nautilus',
  'coaster',
  'seabird',
  'shearwater',
  'beluga',
  'beachcomber',
  'gam',
  'underwater',
  'oceangoing',
  'prawn',
  'narwhal',
  'monkfish',
  'kingfish',
  'bowhead',
  'subaquatic',
  'porpoise',
  'transatlantic',
  'shovelnose',
  'seashell',
  'plankton',
  'halibut',
  'kelp',
  'seaweed',
  'tsunami',
  'cetacean',
  'seahorse',
  'hake',
  'seaboard',
  'mussel',
  'archipelago',
  'blackfish',
  'gopher',
  'brine',
  'reefer',
  'jason',
  'iceberg',
  'main',
  'anchovy',
  'buoy',
  'submarine',
  'seafaring',
  'hardhead',
  'brachiopod',
  'manatee',
  'atlantic',
  'mermaid',
  'baleen',
  'groundfish',
  'planktonic',
  'orion',
  'amphibian',
  'seaward',
  'offshore',
  'oversea',
  'cancer',
  'beagle',
  'chiton',
  'tide',
  'dogfish',
  'serval',
  'hydrosphere',
  'moray',
  'lobster',
  'crust

In [63]:
getPredFromDesc(model, 'a person who is very knowledgeable about many subjects', 100)

(['scholar',
  'pedant',
  'student',
  'expert',
  'professor',
  'intellectual',
  'literate',
  'adept',
  'polyglot',
  'bookworm',
  'lector',
  'educator',
  'thinker',
  'monitor',
  'teacher',
  'master',
  'virtuoso',
  'highbrow',
  'reader',
  'proficient',
  'mastering',
  'doctoring',
  'knowledgeable',
  'skeptic',
  'amateur',
  'connoisseur',
  'pedagogue',
  'technocrat',
  'doctored',
  'savant',
  'scientist',
  'bookman',
  'philosopher',
  'erudition',
  'conversant',
  'doctor',
  'gnostic',
  'cpa',
  'appreciator',
  'humanist',
  'collector',
  'pundit',
  'learning',
  'lettered',
  'mandarin',
  'science',
  'informant',
  'generalist',
  'artist',
  'scholastic',
  'teller',
  'learner',
  'informer',
  'tea',
  'technologist',
  'inquisitor',
  'educated',
  'lore',
  'fellow',
  'orator',
  'university',
  'specialist',
  'technician',
  'sage',
  'brains',
  'stargazer',
  'scholarship',
  'intellect',
  'study',
  'exegete',
  'empiric',
  'observer',
  

In [24]:
model_name = type(model).__name__
filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
with open(filename, 'wb+') as f:
    torch.save({'state_dict': model.state_dict()}, f)

In [35]:
x, y, yvecs = next(iter(train_loader))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "../code/dataset.py", line 254, in collate_fn
    Yvecs = self.embeddings.get_vecs(Ys)
  File "../code/dataset.py", line 55, in get_vecs
    vecs = [self.embeddings[self.stoi[t]] for t in tokens]
  File "../code/dataset.py", line 55, in <listcomp>
    vecs = [self.embeddings[self.stoi[t]] for t in tokens]
RuntimeError: CUDA error: initialization error


In [32]:
word2vec.embeddings.shape

torch.Size([75100, 300])