In [2]:
import os.path
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["PYTORCH_USE_CUDA_DSA"] = "1"


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import re
import torch
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler

from sklearn.utils import resample
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader

if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")


There are 1 GPU(s) available.
We will use the GPU: Tesla V100S-PCIE-32GB


In [3]:
from my_model import CustomModel, PT5_classification_model, create_dataset, load_model_

[2024-05-07 01:54:27,373] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [4]:
from utilites import MyLabelEncoder, add_spaces, balance_majority

In [5]:
from EpiNet import EpiTEINet

In [6]:
from TCRpeg import TCRpeg

In [7]:
def random_recombination(df, column_name, epitope_dist, tcr_dist, ratio):
    unique_epitopes = df['antigen.epitope'].unique()
    unique_tcrs = df[column_name].unique()
    conversion_df = df[['antigen.epitope', column_name]]
    positive_pairs = set([tuple(x) for x in conversion_df.to_numpy()])

    # We want to weight the tcr choice by frequency in data
    epitope_freq_array = [epitope_dist[peptide] / len(df) for peptide in unique_epitopes]
    tcr_freq_array = [tcr_dist[tcr] / len(df) for tcr in unique_tcrs]
    
    neg_pairs = set()
    for pep in unique_epitopes:
        i = 0
        pairs_to_generate = round(epitope_dist[pep] * ratio)
        while i < pairs_to_generate:
            tcr = np.random.choice(unique_tcrs, p=tcr_freq_array)
            pair = (pep, tcr)
            if pair not in positive_pairs and pair not in neg_pairs:
                neg_pairs.add(pair)
                i += 1
            
    negative_data = pd.DataFrame(neg_pairs, columns = ['antigen.epitope', column_name])
    negative_data = negative_data.assign(affinity=0)
    return negative_data

def conc_tensors(tensor1, tensor2):
    tensor1_squeezed = tensor1.squeeze(1)
    combined_tensor = torch.cat((tensor1_squeezed, tensor2), dim=0)
    return combined_tensor

In [8]:
model_epi = TCRpeg(hidden_size=768,num_layers = 3,load_data=False,embedding_path='aa_emb_epi.txt')
model_epi.create_model(load=True, path = 'encoder_epi.pth')

In [9]:
model = EpiTEINet(en_epi = model_epi,cat_size=768*2,normalize=True,weight_decay = 0).to(device)

In [10]:
# epi_emb = 
model.get_emb(['KLGGALQAK'], model_epi, model_epi.model).shape

torch.Size([1, 768])

In [11]:
vdjdb = pd.read_csv('../data/vdjdb_full.txt', sep='\t', low_memory=False)

In [12]:
le = MyLabelEncoder()
groups = vdjdb['antigen.epitope'].value_counts().index

le.fit(groups)
groups
le.transform(groups)

array([   0,    1,    2, ..., 1166, 1167, 1168])

In [13]:
vdjb_short = vdjdb[['cdr3.alpha', 'cdr3.beta', 'antigen.epitope']]

In [14]:
vdjb_short['antigen.epitope'].value_counts()[vdjb_short['antigen.epitope'].value_counts()>500].shape[0]

20

In [15]:
vdjb_short['antigen.epitope']= le.transform(vdjb_short['antigen.epitope'])
vdjb_short

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  vdjb_short['antigen.epitope']= le.transform(vdjb_short['antigen.epitope'])


Unnamed: 0,cdr3.alpha,cdr3.beta,antigen.epitope
0,CIVRAPGRADMRF,CASSYLPGQGDHYSNQPQHF,38
1,,CASSFEAGQGFFSNQPQHF,38
2,CAVPSGAGSYQLTF,CASSFEPGQGFYSNQPQHF,38
3,CAVKASGSRLT,CASSYEPGQVSHYSNQPQHF,38
4,CAYRPPGTYKYIF,CASSALASLNEQFF,38
...,...,...,...
62172,CMDEGGSNYKLTF,CASSVRSTDTQYF,130
62173,CSLYNNNDMRF,CASSLRYTDTQYF,130
62174,CALSTDSWGKLQF,CASSPGQGGDNEQFF,254
62175,CAPQGATNKLIF,CASSLGAGGQETQYF,254


In [16]:
vdjb_short = vdjb_short[vdjb_short['antigen.epitope']<=20]

In [17]:
vdjb_short['antigen.epitope']= le.inverse_transform(vdjb_short['antigen.epitope'])
vdjb_short

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  vdjb_short['antigen.epitope']= le.inverse_transform(vdjb_short['antigen.epitope'])


Unnamed: 0,cdr3.alpha,cdr3.beta,antigen.epitope
385,,CASSPQTGTGGYGYTF,NLVPMVATV
386,,CASSPQTGTGGYGYTF,NLVPMVATV
387,,CASSPLFGTSGGETYYF,NLVPMVATV
388,,CASSPQTGTGGYGYTF,NLVPMVATV
389,,CASSPQTGASYGYTF,NLVPMVATV
...,...,...,...
61810,CAGLNYGGSQGNLIF,CASSWRQGGSIRESYTF,TFEYVSQPFLMDLE
61811,,CASSLSSGWPYGYTF,TFEYVSQPFLMDLE
61812,CAALNYGGSQGNLIF,CASSDRGTGLNGYTF,TFEYVSQPFLMDLE
61813,CAGLNYGGSQGNLIF,CASGPGGMTEAFF,TFEYVSQPFLMDLE


In [18]:
vdjb_short = vdjb_short.fillna('')
vdjb_short['cdr3aa'] = vdjb_short['cdr3.alpha'] + vdjb_short['cdr3.beta'].apply(lambda s: s[::-1])
vdjb_cdr = vdjb_short[['cdr3aa', 'antigen.epitope']]
# vdjb_cdr = vdjb_short[['cdr3.beta', 'antigen.epitope']]


In [19]:
vdjb_cdr = balance_majority(vdjb_cdr, 'antigen.epitope', max_count=500)

In [20]:
vdjb_cdr.dropna(inplace=True)

In [21]:
vdjb_cdr['antigen.epitope'].value_counts()

CTPYDINQM               500
FRDYVDRFYKTLRAEQASQE    500
TFEYVSQPFLMDLE          500
LLWNGPMAV               500
KRWIILGLNK              500
NEGVKAAW                500
LLLGIGILV               500
SPRWYFYYL               500
SSYRRPVGI               500
IVTDFSVIK               500
PKYVKQNTLKLAT           500
KLGGALQAK               500
TTPESANL                500
GLCTLVAML               500
RAKFKQLL                500
YLQPRTFLL               500
AVFDRKSDAK              500
ELAGIGILTV              500
NLVPMVATV               500
GILGFVFTL               500
TTDPSFLGRY              448
Name: antigen.epitope, dtype: int64

In [22]:
vdjb_cdr['affinity'] = 1

In [23]:
vdjb_cdr

Unnamed: 0,cdr3aa,antigen.epitope,affinity
50061,CAASAFISNTGKLIFFYQTEIGTGDDSSAC,TTDPSFLGRY,1
50077,CALVPNARLMFFYQTDGAWRDTSAC,TTDPSFLGRY,1
50078,CAVLGVYNQGGKLIFFTYGYGGLPFSSAC,TTDPSFLGRY,1
50079,CAPAVYNFNKFYFFHLPSNDLSSAC,TTDPSFLGRY,1
50081,CVGGGADGLTFFYQEGSGGAGLSSAC,TTDPSFLGRY,1
...,...,...,...
18310,CALSGLGYGNKLVFFFAETGGRGLGTSAC,PKYVKQNTLKLAT,1
52484,FTYGMPEGQLLSSAC,PKYVKQNTLKLAT,1
18258,CAENMRGSNYKLTFFFAETNSQRGDLSSAC,PKYVKQNTLKLAT,1
52610,FYQTDVAGGGNSAC,PKYVKQNTLKLAT,1


In [24]:
df = vdjb_cdr.copy()

ratio = 0.8
tcr_distribution = df['cdr3aa'].value_counts()
epitope_distribution = df['antigen.epitope'].value_counts()
negative_data = random_recombination(df, 'cdr3aa', epitope_distribution, tcr_distribution, ratio)
negative_data


Unnamed: 0,antigen.epitope,cdr3aa,affinity
0,ELAGIGILTV,FYQPQNSVRDLSSAC,0
1,AVFDRKSDAK,CAGQNYGGSQGNLIF,0
2,NEGVKAAW,FFQENYSRDRSSAC,0
3,TTPESANL,CVVNVMDDMRFFFLEGTNRDPSSAC,0
4,NEGVKAAW,CLVGDIQGGGGKLIFFFLEGTDGEVSC,0
...,...,...,...
8353,ELAGIGILTV,FYQPQNSVRNLSSAC,0
8354,CTPYDINQM,CAYRGYSGGGADGLTFFTYGYYAQKSSAC,0
8355,KRWIILGLNK,FFLKENTLGRGELSSAC,0
8356,NEGVKAAW,FYQPQNSNGNLSSAC,0


In [25]:
full_data = pd.concat([df, negative_data])
full_data = full_data.reset_index(drop=True)
full_data

Unnamed: 0,cdr3aa,antigen.epitope,affinity
0,CAASAFISNTGKLIFFYQTEIGTGDDSSAC,TTDPSFLGRY,1
1,CALVPNARLMFFYQTDGAWRDTSAC,TTDPSFLGRY,1
2,CAVLGVYNQGGKLIFFTYGYGGLPFSSAC,TTDPSFLGRY,1
3,CAPAVYNFNKFYFFHLPSNDLSSAC,TTDPSFLGRY,1
4,CVGGGADGLTFFYQEGSGGAGLSSAC,TTDPSFLGRY,1
...,...,...,...
18801,FYQPQNSVRNLSSAC,ELAGIGILTV,0
18802,CAYRGYSGGGADGLTFFTYGYYAQKSSAC,CTPYDINQM,0
18803,FFLKENTLGRGELSSAC,KRWIILGLNK,0
18804,FYQPQNSNGNLSSAC,NEGVKAAW,0


In [26]:
N_LABELS = full_data['antigen.epitope'].nunique()

tcr_mod = load_model_('../models_ft/tcr_epit.pth', mod_type='TCR-bert', num_labels=N_LABELS)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at wukevin/tcr-bert and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([45, 768]) in the checkpoint and torch.Size([21, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([45]) in the checkpoint and torch.Size([21]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
tcr_mod.maximun_len = 45
tcr_mod.to(device)

CustomModel(
  (model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(26, 768, padding_idx=21)
        (position_embeddings): Embedding(64, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=

In [28]:
lbls = []
final_tens = []#torch.empty(0, 0, 768)
tcr_mod = tcr_mod.to(device)
model = model.to(device)

for seq in tqdm(full_data.index):
    # print(seq)
    seq, epi, lab = full_data.loc[seq]
    en_dict = tcr_mod.tokenizer.encode_plus(seq, add_special_tokens = True, 
                                      max_length = 45, pad_to_max_length = True, return_attention_mask = True, return_tensors = 'pt')
    
    input_ids_test, att_mask_test = en_dict['input_ids'], en_dict['attention_mask']
   
    input_ids_test = input_ids_test.to(device)
    att_mask_test= att_mask_test.to(device)
   
    with torch.no_grad():
        outputs = tcr_mod.model(input_ids_test, att_mask_test, output_hidden_states=True)
        
    tcr_emb = outputs.hidden_states
    tcr_emb = torch.stack(list(tcr_emb), dim=0)
    
    
    tcr_emb = torch.mean(tcr_emb, dim = 2)
    tcr_emb = torch.mean(tcr_emb, dim = 0)
    
    epi_emb = model.get_emb([epi], model_epi, model_epi.model)
    # print(epi_emb.shape)
    
    combined_tensor = conc_tensors(tcr_emb, epi_emb)
        
    final_tens.append(combined_tensor)
    lbls.append(lab)

  0%|          | 0/18806 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 18806/18806 [03:37<00:00, 86.46it/s]


In [29]:
lbls_tens = torch.tensor(lbls, dtype=torch.float32).to(device)

In [30]:
final_tens[0].shape

torch.Size([2, 768])

In [31]:
# final_tens_2 = [tens[0]*tens[1] for tens in final_tens]
# final_tens_2[0]

In [32]:
# final_tens_np = [tens.to('cpu').detach().numpy() for tens in final_tens]

In [33]:

# df = pd.DataFrame({'embedding': final_tens_np, 'label': lbls})

In [34]:
# concatenated_tensor = torch.stack(final_tens, dim=0)

# concatenated_tensor = torch.stack(final_tens, dim=0)#.view(20896, -1)

concatenated_tensor = torch.stack(final_tens, dim=0)#.view(20896, -1)
concatenated_tensor.shape

torch.Size([18806, 2, 768])

In [35]:
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

lbls_tens = torch.tensor(lbls, dtype=torch.float32).to(device)
dataset = TensorDataset(concatenated_tensor, lbls_tens)

In [36]:
# train_size = int(0.6 * len(lbls))
# val_size = int(0.2 * len(lbls))
# test_size = len(lbls) - train_size - val_size

In [37]:
# train_df, val_df, test_df = random_split(dataset, [train_size, val_size, test_size])


In [38]:
# batch_size = 64

# train_loader = DataLoader(train_df,
#                 batch_size = batch_size, shuffle=True)

# # val_loader = DataLoader(val_df,
# #                 sampler = SequentialSampler(val_df),
# #                 batch_size = batch_size)
    
# test_loader = DataLoader(test_df,
#                 # sampler = RandomSampler(test_df),
#                 batch_size = batch_size, shuffle=True)


In [39]:
dataset[0][1]

tensor(1., device='cuda:0')

In [40]:
# torch.tensor(1).shape

In [41]:
device = 'cpu'

In [42]:
# import torch.nn as nn
# import torch.optim as optim


# class SimpleClassifier(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.fc1 = nn.Linear(768 * 14, 256)
#         self.relu = nn.ReLU()
#         self.fc2 = nn.Linear(256, 1)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.relu(x)
#         x = self.fc2(x)
#         x = self.sigmoid(x)
#         return x


# input_size = 768 * 14  
# hidden_size = 128  
# output_size = 1  
# clf_model = SimpleClassifier()

# num_epochs = 4
# criterion = nn.CrossEntropyLoss()

# optimizer = optim.Adam(clf_model.parameters(), lr = 0.001)
# total_steps = len(train_loader) * num_epochs


In [43]:
# clf_model.to(device)

# for epoch in range(num_epochs):
#     clf_model.train()
#     running_loss = 0.0
#     for step, batch in enumerate(train_loader):
#         optimizer.zero_grad()  
#         batch = tuple(t.to(device) for t in batch)
#         inputs, labels = batch

#         inputs = inputs.view(-1, input_size)
#         # print(inputs.shape)
#         outputs = clf_model(inputs)
#         # print(outputs)
#         outputs = outputs.squeeze()#.float()
#         print(outputs)
#         # optimizer.zero_grad()
#         loss = criterion(outputs, labels)
        
#         running_loss += loss.item()
#         print(loss)
#         # optimizer.zero_grad()
#         loss.backward()  # Calculate gradients
#         torch.nn.utils.clip_grad_norm_(clf_model.parameters(), 1.0)
#         optimizer.step()  # Update weights
    
#     #Validation
#     print("")
#     print("Running Validation...")
    
#     clf_model.eval()
#     correct = 0
#     total = 0
    
#     for batch in val_loader:
#         batch = tuple(t.to(device) for t in batch)
#         inputs, labels = batch
#         with torch.no_grad():
#             inputs = inputs.view(-1, input_size)
#             outputs = clf_model(inputs)
#             predicted = torch.round(outputs)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
            
#     print(f"Accuracy on validation set: {100 * correct / total}%")


In [44]:
concatenated_tensor.shape
X = concatenated_tensor
y = lbls_tens
X.shape

torch.Size([18806, 2, 768])

In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# Define your neural network architecture
class SimpleClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


In [46]:

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)


train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


  X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
  X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
  y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
  y_test_tensor = torch.tensor(y_test, dtype=torch.float32)


In [54]:
class FCNet(nn.Module):
    def __init__(self):
        super(FCNet, self).__init__()
        self.fc1 = nn.Linear(768 * 2, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 2)
        self.relu = nn.ReLU()
        self.softmax = nn.Sigmoid()  # Softmax for multi-class classification

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.softmax(x)
        return x

In [58]:
def get_accuracy(model, dataloader, device='cuda'):
    """
    model - обученная нейронная сеть
    dataloader - даталоадер, на котором вы хотите посчитать accuracy
    """
    model = model.to(device)
    correct = 0
    total = 0
    with torch.no_grad(): # Тензоры внутри этого блока будут иметь requires_grad=False
        for images, labels in dataloader:
            images=images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # print(labels)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total

    return accuracy

In [59]:
fc_net = FCNet().to(device)


loss_function = nn.CrossEntropyLoss()
lr = 0.01 
optimizer = torch.optim.Adam(fc_net.parameters(), lr=lr)

In [62]:
n_epochs = 10
loss_history = []
fc_net.to(device)


optimizer = torch.optim.Adam(fc_net.parameters(), lr=3e-10)

for epoch in range(n_epochs):
    fc_net.train()
    curr_loss = 0.0
    for batch in train_loader:
        # batch.shape
        im, lab = tuple(t.to(device) for t in batch)
        fc_net = fc_net.to(device)
        # print(im)
        optimizer.zero_grad()
        outputs = fc_net(im)
        # print(outputs)
        loss = loss_function(outputs, lab.long())

        loss.backward()
        optimizer.step()

        curr_loss += loss.item()

    tr_loss = curr_loss / len(train_loader)
    tr_accuracy = get_accuracy(fc_net, train_loader)
    test_accuracy = get_accuracy(fc_net, test_loader)

    print(f"Epoch {epoch+1}/{n_epochs}, "
          f"Train Loss: {tr_loss:.4f}, "
          f"Train Accuracy: {tr_accuracy:.4f}, "
          f"Test Accuracy: {test_accuracy:.4f}")

print("Training complited!")

Epoch 1/10, Train Loss: 0.6872, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 2/10, Train Loss: 0.6874, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 3/10, Train Loss: 0.6874, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 4/10, Train Loss: 0.6873, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 5/10, Train Loss: 0.6872, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 6/10, Train Loss: 0.6875, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 7/10, Train Loss: 0.6875, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 8/10, Train Loss: 0.6872, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 9/10, Train Loss: 0.6874, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Epoch 10/10, Train Loss: 0.6875, Train Accuracy: 0.5530, Test Accuracy: 0.5659
Training complited!


In [66]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for tens, labels in test_loader:
        outputs = fc_net(tens)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # print(predicted)
        # print(labels)
        correct += (predicted == labels).sum().item()
    
accuracy = correct / total

In [67]:
print(f"Accuracy on test set: {100*correct / total}%")


Accuracy on test set: 56.59223817118554%


In [70]:
device = 'cuda'



In [79]:
input_size = 768 *2 
hidden_size = 128  
output_size = 2  
model = SimpleClassifier(input_size, hidden_size, output_size).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=3e-5)

model = model.to(device)
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        # print(labels)
        inputs = inputs.view(-1, input_size).to(device)  
        optimizer.zero_grad()
        outputs = model(inputs)
        # print(torch.argmax(outputs, dim =1).float())
        # print(labels)
        loss = criterion(torch.argmax(outputs, dim =1).squeeze().float(), labels)
        loss.requires_grad = True
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")

# Evaluate the model



Epoch [1/20], Loss: 51.924097664543524
Epoch [2/20], Loss: 51.877653927813164
Epoch [3/20], Loss: 51.78476645435244
Epoch [4/20], Loss: 51.73832271762208
Epoch [5/20], Loss: 51.877653927813164
Epoch [6/20], Loss: 51.73832271762208
Epoch [7/20], Loss: 51.8312101910828
Epoch [8/20], Loss: 51.877653927813164
Epoch [9/20], Loss: 51.877653927813164
Epoch [10/20], Loss: 51.8312101910828
Epoch [11/20], Loss: 51.8312101910828
Epoch [12/20], Loss: 51.8312101910828
Epoch [13/20], Loss: 51.8312101910828
Epoch [14/20], Loss: 51.877653927813164
Epoch [15/20], Loss: 51.8312101910828
Epoch [16/20], Loss: 51.8312101910828
Epoch [17/20], Loss: 51.78476645435244
Epoch [18/20], Loss: 51.8312101910828
Epoch [19/20], Loss: 51.78476645435244
Epoch [20/20], Loss: 51.8312101910828


In [80]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.view(-1, input_size).to(device)  # Flatten the input tensor
        outputs = model(inputs)
        predicted = torch.argmax(outputs, dim =1).squeeze().float()
        # print(predicted.view(-1))
        total += labels.size(0)
        correct += (predicted == labels).sum().item()


In [81]:
print(f"Accuracy on test set: {100*correct / total}%")


Accuracy on test set: 48.165869218500795%
