In [1]:
import utils
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import torch
from typing import List, Tuple

In [2]:
class SkipGram(torch.nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.out = nn.Linear(emb_size, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, word_id: int):
        x = self.emb(word_id)
        x = self.out(x)
        return F.log_softmax(x, dim=-1)
    
    def get_emb(self, word_id: int):
        return self.emb(word_id)


def batched_data(processed_data: List[List[str]], batch_size: int, wdict: dict):
    batch_x = []
    batch_y = []
    for doc in processed_data:
        for i in range(len(doc)):
            word = wdict[doc[i]]
            for j in range(4):
                idx = -2 + j
                if idx >= 0:
                    idx += 1
                if i + idx >= 0 and i + idx < len(doc):
                    context = wdict[doc[i + idx]]
                    if len(batch_x) < batch_size:
                        batch_x.append(word)
                        batch_y.append(context)
                    else:
                        yield torch.tensor(batch_x), torch.tensor(batch_y)
                        batch_x = []
                        batch_y = []

def create_wdict(processed_data: List[List[str]]) -> dict:
    wdict = {}
    for doc in processed_data:
        for word in doc:
            if word not in wdict:
                wdict[word] = len(wdict)
    return wdict

In [3]:
processed_data = utils.import_smiles('data/HIV.csv', skiprow=True)
wdict = create_wdict(processed_data)

In [4]:
batch_size = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = SkipGram(len(wdict), 300).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 


for epoch in range(10):
    with tqdm() as pbar:
        for word, context in batched_data(processed_data, batch_size, wdict):
            word = word.to(device)
            context = context.to(device)
            optimizer.zero_grad()
            out = model(word)
            loss = F.nll_loss(out, context)

            loss.backward()
            optimizer.step()
            pbar.update(1)

cuda


6872it [00:14, 474.91it/s]
6872it [00:12, 555.31it/s]
6872it [00:12, 537.42it/s]
6872it [00:13, 515.50it/s]
6872it [00:12, 533.55it/s]
6872it [00:12, 531.97it/s]
6872it [00:12, 539.14it/s]
6872it [00:13, 518.29it/s]
6872it [00:12, 533.07it/s]
6872it [00:12, 554.24it/s]


In [5]:
torch.save(model.state_dict(), 'models/emb.pt')
with open('models/wdict.p', 'wb') as f:
    pickle.dump(wdict, f)