In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import csv
import json
import os
import sys

In [11]:
with open('../Data/train.csv', 'r') as f:
    reader = csv.reader(f)
    train_data = list(reader)[1:]
  
train_data

[['CNC[C@H](c1ccc(Cl)c(Cl)c1)[C@@H](O)c1cccc(N)c1', '0'],
 ['C[C@@H]1Cn2ncc(C3CCN(S(C)(=O)=O)CC3)c2CN1c1ccnc2[nH]ccc12', '0'],
 ['CNC(=O)c1c(NCC2CCC3(CCCC3)CC2)nc(C#N)nc1OCC1CCN(C)CC1', '1'],
 ['Cc1cc(CNc2nc(N)nc3ccn(Cc4ccccn4)c23)no1', '1'],
 ['C[C@@]1(c2cc(CNCCC(F)(F)F)c(F)cc2F)CCSC(N)=N1', '1'],
 ['NC(c1ccccc1)C(O)(c1cccnc1)c1cccnc1', '0'],
 ['CCCS(=O)(=O)N1CC(CC#N)(n2cc(-c3ccnc4[nH]ccc34)cn2)C1', '0'],
 ['C[C@@H]1COCCN1c1cc(C2(S(C)(=O)=O)CC2)nc(-c2cccc3[nH]ccc23)n1', '0'],
 ['CN1CCC(COCc2cc(Cl)cc(C(F)(F)F)n2)(c2ccccc2)CC1', '1'],
 ['COC1COCCC1N[C@@H]1C[C@H]2C[C@H](F)C[C@@]2(C(=O)N2CCc3ncc(C(F)(F)F)cc3C2)C1',
  '0'],
 ['Cc1nccn1Cc1cn(-c2ccc(Cl)cc2Cl)nn1', '1'],
 ['CNC(=O)C(c1cc(-c2ncnc3cc(N4CCOCC4)ccc23)c(F)cc1Cl)c1nccnc1C', '0'],
 ['CCC(CN)=C1CN(c2nc(Oc3cnc(CO)nc3)nc3[nH]c4c(NC)cc(F)cc4c23)C1', '0'],
 ['Cc1ccc(NC(=O)c2cc(C(F)(F)F)cnn2)cc1-c1ccc2c(c1)OC1(CCOCC1)CC(=O)N2', '1'],
 ['C[C@@H]1CN(C(=O)OC2(C(F)(F)F)COC2)CCN1c1ncc(OCc2ccc(S(C)(=O)=O)cc2F)cn1',
  '1'],
 ['NC1(C(=O)NC(CCCN2C

In [12]:
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.tokenizer import GPT4Tokenizer
tokenizer = GPT4Tokenizer()
tokenizer.load_vocab('vocab.json')
tokenizer.vocab[512] = '<none>'


In [13]:
print(tokenizer.encode("NC(=O)c1ccc(Cl)cc1)[C@H]1[C@@H]2C[C@@H](n3cnc4cccc(F)c43)C[C@@H]21CNC[C@H](c1ccc2ccccc2c1)[C@H](O)c1cncc(OC)c1COc1cnc(NCCNCCCF)cc1[C@@H]1c2[nH]c3ccccc3c2C[C@@H](C)N1CC(F)(F)FC[C@@]1(c2cc(NC(=O)c3ccc(C#N)cn3)ccc2F)Cn2cc(C#N)n"))


[375, 269, 335, 324, 487, 438, 347, 413, 386, 256, 270, 288, 356, 488, 272, 289, 278, 352, 360, 325, 278, 302, 338, 256, 412, 41, 308, 79, 338, 313, 257, 78, 390, 498, 407, 485, 358, 400, 408, 457, 371, 336, 49, 385, 397, 284, 280, 454, 273, 361, 439, 67, 420, 259, 344, 296]


In [14]:
encoded_data = []
for row in train_data:
    text = row[0]  # SMILES string
    label = int(row[1])  # Label
    tokens = tokenizer.encode(text)
    encoded_data.append((torch.tensor(tokens), label))

In [15]:
# pad all sequences to the same length
padded_data = []
for tokens, label in encoded_data:
    padded_tokens = F.pad(tokens, (0, 64 - len(tokens)), value=512)
    padded_data.append((padded_tokens, torch.tensor(label)))

In [None]:


class ClassificationTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=128, num_heads=8, num_layers=6, max_len=64, num_classes=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos_embedding = nn.Embedding(max_len, emb_dim)
        
        encoder_layers = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.fc = nn.Linear(emb_dim, num_classes)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        batch_size, seq_length = x.size()
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = self.embedding(x) + self.pos_embedding(torch.arange(seq_length, device=x.device).unsqueeze(0))
        x = torch.cat((cls_tokens, x), dim=1)
        x = x.permute(1, 0, 2)  # Transformer expects (seq_length, batch_size, emb_dim)
        x = self.transformer(x)
        x = x[0]  # CLS token
        x = self.dropout(x)
        logits = self.fc(x)
        return logits

class TokenDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data  # List of token tensors
        self.labels = labels  # List of labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def get_batches(dataset, batch_size):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for batch in loader:
        tokens, labels = batch
        yield tokens.to(device), labels.to(device)

def train_model(model, train_dataset, val_dataset, epochs=10, batch_size=32, lr=1e-3):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for tokens, labels in train_loader:
            tokens, labels = tokens.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(tokens)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        val_accuracy = evaluate(model, val_loader)
        print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
    
    torch.save(model.state_dict(), 'classification_transformer.pth')

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for tokens, labels in loader:
            tokens, labels = tokens.to(device), labels.to(device)
            outputs = model(tokens)
            _, preds = torch.max(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Example usage
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    vocab_size = 513
    tokens = [item[0] for item in padded_data]
    labels = [item[1] for item in padded_data]
    
    dataset = TokenDataset(tokens, labels)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    model = ClassificationTransformer(vocab_size=vocab_size, max_len=64).to(device)
    train_model(model, train_dataset, val_dataset, epochs=10, batch_size=32, lr=1e-3)

TypeError: list indices must be integers or slices, not tuple