In [1]:
import os

import numpy as np
import torch
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter


In [2]:
#Embeddings and Trainer

class Embers(nn.Module):
    """Two embedding tables whose vectors are scored against each other by dot product.

    ``focus`` embeds the first column of every input row and ``contrast`` embeds
    the remaining columns; the forward pass returns the dot product of the focus
    vector with each contrast vector.
    """

    def __init__(self, vocab_size, emb_complexity):
        """Create both tables with ``vocab_size`` rows of ``emb_complexity`` dimensions.

        The global torch RNG is re-seeded with 1024 on every construction, so
        two instances built with the same sizes start from identical weights.
        """
        torch.manual_seed(1024)
        super().__init__()
        self.focus = nn.Embedding(vocab_size, emb_complexity)
        self.contrast = nn.Embedding(vocab_size, emb_complexity)
        # Replace the default initialisation with Xavier-uniform, focus table first.
        for table in (self.focus, self.contrast):
            nn.init.xavier_uniform_(table.weight)

    def forward(self, inp):
        """Score the first column of ``inp`` against every other column.

        Args:
            inp: 2-D tensor of vocabulary indices, one row per sample.

        Returns:
            Tensor of shape ``(rows, 1, columns - 1)`` holding raw dot products.
        """
        center_ids, context_ids = torch.hsplit(inp, (1,))
        center_vecs = self.focus(center_ids)
        context_vecs = self.contrast(context_ids)
        # (rows, 1, dim) x (rows, dim, columns - 1) -> (rows, 1, columns - 1)
        return torch.bmm(center_vecs, context_vecs.transpose(1, 2))


class Embedding_Trainer:
    """Runs one pass of embedding training over a single data group.

    Each trainer builds an ``Embers`` model, optionally resumes from the
    checkpoint of the preceding data group, trains for one epoch with mixed
    precision, logs a running mean loss to TensorBoard and saves a checkpoint
    named after its own data group.
    """

    # Target row applied to the 8 context scores of every sample.
    # Presumably the dataset interleaves true context words with negative
    # samples in this order -- confirm against the data preparation code.
    LABEL_PATTERN = (1, 0, 1, 0, 1, 0, 1, 0)

    def __init__(self, dataset, data_group, batch_size, vocab_size, embedding_complexity, previous_batches,
                 device='cuda', model_dir='models',
                 log_dir='C:/Users/BBA/Coding/tblogs/word_embeddings/v1.0'):
        """Set up data loading, model, optimiser and logging.

        Args:
            dataset: indexable collection of samples handed to ``DataLoader``.
            data_group: ``(group, part)`` pair; ``part`` is 1 or 2. Determines
                which checkpoint is loaded and which one is written.
            batch_size: number of samples per batch.
            vocab_size: number of rows in each embedding table.
            embedding_complexity: dimensionality of each embedding vector.
            previous_batches: number of batches already logged by earlier
                trainers; used as the TensorBoard step offset.
            device: device the model and batches are moved to.
            model_dir: directory checkpoints are read from and written to.
            log_dir: TensorBoard log directory.
        """
        self.previous_batches = previous_batches
        self.data_group = data_group
        # Parts alternate 1, 2 within a group, so the predecessor of (g, 1) is (g - 1, 2).
        if data_group[1] == 1:
            last_group = (data_group[0] - 1, 2)
        else:
            last_group = (data_group[0], 1)
        self.batch_size = batch_size
        self.data_size = len(dataset)
        self.device = torch.device(device)
        # Mixed precision is only meaningful on CUDA; elsewhere run in full precision.
        self.use_amp = self.device.type == 'cuda'
        self.model_dir = model_dir
        self.tdata = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        self.emb = Embers(vocab_size, embedding_complexity)
        if tuple(self.data_group) != (1, 1):
            # The model is still on the CPU here, so restore the tensors there too.
            self.emb.load_state_dict(torch.load(self._checkpoint_path(last_group), map_location='cpu'))
        self.tblog = SummaryWriter(log_dir=log_dir)
        self.emb = self.emb.to(self.device)
        self.scaler = GradScaler(enabled=self.use_amp)
        self.lossfn = nn.BCEWithLogitsLoss(reduction='none')
        self.optimizer = torch.optim.NAdam(self.emb.parameters())
        # Labels for one full batch, kept on the training device. Shorter
        # batches take a slice, so this tensor is never overwritten.
        self.ltensor = torch.tensor(
            self.LABEL_PATTERN, dtype=torch.float32, device=self.device
        ).repeat(self.batch_size, 1, 1)

    def _checkpoint_path(self, group):
        """Return the checkpoint file path for a ``(group, part)`` pair."""
        return os.path.join(self.model_dir, f'embeddings{group[0]}-{group[1]}')

    def train(self):
        """Train for one epoch, log the running per-sample loss and save a checkpoint.

        ``previous_batches`` is incremented once per processed batch, so after
        the call it holds the global step of the last logged point.
        """
        self.emb.train()
        loss_total = 0.0
        samples_seen = 0
        for i, datum in enumerate(self.tdata):
            datum = datum.to(self.device)
            with autocast(enabled=self.use_amp):
                # Cast to float32 on the device itself; the loss never leaves the device.
                logits = self.emb(datum).float()
                labels = self.ltensor[:logits.shape[0]]
                loss = self.lossfn(logits, labels).sum()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            # Accumulate a plain float so no autograd history is kept alive across batches.
            loss_total += loss.item()
            samples_seen += logits.shape[0]
            self.previous_batches += 1
            mean_loss = loss_total / samples_seen
            self.tblog.add_scalar('Train Loss / Batch', mean_loss, self.previous_batches)
            if i > 0 and i % 2048 == 0:
                print(f'Loss: {mean_loss}')
        # Create the directory first so a finished epoch is not lost to a missing folder.
        os.makedirs(self.model_dir, exist_ok=True)
        torch.save(self.emb.state_dict(), self._checkpoint_path(self.data_group))
        self.tblog.flush()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()



In [None]:
#Training Sequence for All Data Sets
# Trains on data groups 1-1, 1-2, 2-1, ... 4-2 in order. Each trainer resumes
# from the checkpoint written by the previous one and continues the same
# TensorBoard step axis.
BATCH_SIZE = 1024
VOCAB_SIZE = 155000
EMBEDDING_DIM = 128

overall_batches = 0

for x in range(1, 5):
    for y in range(1, 3):
        with open(f'D:\\dstore\\nlp\\w2v\\train{x}-{y}', 'rb') as f:
            trainx = np.load(f)
        emtrain = Embedding_Trainer(trainx, (x, y), BATCH_SIZE, VOCAB_SIZE, EMBEDDING_DIM, overall_batches)
        emtrain.train()
        # The trainer counts the batches it actually processed; reading that
        # counter avoids overcounting by one when the data size is an exact
        # multiple of the batch size.
        overall_batches = emtrain.previous_batches
        # Release the data and model before loading the next group.
        del emtrain
        del trainx


In [192]:
#Results and Word Similarity

# NOTE(review): allow_pickle=True unpickles the file, which can execute
# arbitrary code -- only load vocabulary files you created yourself.
with open(r'D:\dstore\nlp\w2v\word_list_i', 'rb') as f:
    wdct = np.load(f, allow_pickle=True)
# Inverse lookup: vocabulary index -> word.
idct = {index: word for word, index in wdct.items()}

# Restore the final checkpoint onto the CPU; the similarity queries below do
# not need a GPU even though the checkpoint was saved from one.
trained = Embers(155000, 128)
trained.load_state_dict(torch.load('models/embeddings4-2', map_location='cpu'))
# Only the focus table is used for word similarity.
emb = trained.get_submodule('focus')
emb.eval()
allw = emb.weight.data
# Drop the full model so the unused contrast table can be freed.
del trained

def closest_words(word, top_k=7):
    """Print the ``top_k`` vocabulary words most similar to ``word``.

    Similarity is the cosine between the word's focus embedding and every row
    of the embedding matrix. The query word itself is left out of the listing.

    Args:
        word: a key of ``wdct``; an unknown word raises ``KeyError``.
        top_k: number of neighbours to print (fewer if the vocabulary is smaller).
    """
    query_index = wdct[word]
    # Inference only: no autograd graph is needed for the lookup or the scores.
    with torch.no_grad():
        targetw = emb(torch.tensor(query_index))
        scores = nn.functional.cosine_similarity(targetw, allw)
    # Request one extra hit so the query can be dropped, but never more than
    # the vocabulary holds.
    top_scores, top_indices = torch.topk(scores, min(top_k + 1, scores.numel()))
    shown = 0
    for score, index in zip(top_scores.tolist(), top_indices.tolist()):
        if index == query_index:
            continue
        if shown == top_k:
            break
        print(f'score: {score:.3f} ~ {idct[index]}')
        shown += 1

In [193]:
closest_words('lions')

score: 0.636 ~ eagles
score: 0.572 ~ bisons
score: 0.551 ~ rhinos
score: 0.550 ~ dolphins
score: 0.538 ~ bears
score: 0.536 ~ panthers
score: 0.530 ~ tigers


In [194]:
closest_words('tigers')

score: 0.572 ~ wildcats
score: 0.552 ~ bisons
score: 0.549 ~ cougars
score: 0.530 ~ lions
score: 0.529 ~ eagles
score: 0.521 ~ jaguars
score: 0.516 ~ league


In [195]:
closest_words('bears')

score: 0.538 ~ lions
score: 0.532 ~ badgers
score: 0.513 ~ devils
score: 0.511 ~ beavers
score: 0.494 ~ toothed
score: 0.489 ~ notched
score: 0.486 ~ cougars


In [196]:
closest_words('lion')

score: 0.505 ~ bauble
score: 0.501 ~ hound
score: 0.491 ~ mighty
score: 0.486 ~ boar
score: 0.483 ~ legend
score: 0.483 ~ eagle
score: 0.478 ~ statant


In [197]:
closest_words('tiger')

score: 0.438 ~ jungle
score: 0.438 ~ mule
score: 0.419 ~ cheetah
score: 0.412 ~ kota
score: 0.412 ~ cheetal
score: 0.411 ~ snakes
score: 0.410 ~ bekah


In [198]:
closest_words('bear')

score: 0.563 ~ wolf
score: 0.520 ~ hyenas
score: 0.518 ~ hoofed
score: 0.514 ~ deer
score: 0.511 ~ captive
score: 0.501 ~ nyctereutes
score: 0.497 ~ hunting
