In [2]:
import os

import pandas as pd 
import torch
import torch.nn as nn
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader

In [3]:
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',output_hidden_states = True)

In [4]:
class Data_Reader(Dataset):
    def __init__(self, X:pd.Series, Y:pd.Series, tokenizer, max_len:int = 184) -> None:
        super().__init__()
        self.texts = X['text'].to_list()
        self.labels = Y['label'].to_list()
        self.max_len = max_len
        self.tokenizer = tokenizer
        
    def __len__(self) -> int:
        return len(self.texts)
    
    def __getitem__(self, index) -> list:
        
        encoding = self.tokenizer(self.texts[index], return_tensors='pt',
                                     padding='max_length',max_length = self.max_len,
                                    truncation=True, add_special_tokens=True)
        return [
            torch.tensor(self.labels[index]).long(),
            {'input_ids':encoding['input_ids'],
             'attention_mask':encoding['attention_mask']
            }
            ]

In [68]:
X_train = pd.read_csv("../data/train/X.csv")[:100]
Y_train = pd.read_csv("../data/train/Y.csv")[:100]

X_val = pd.read_csv("../data/val/X.csv")[:20]
Y_val = pd.read_csv("../data/val/Y.csv")[:20]

In [69]:
train_data = Data_Reader(X_train,Y_train,tokenizer,max_len=136)
val_data = Data_Reader(X_val,Y_val,tokenizer,max_len=136)

In [70]:
train_data_batched = DataLoader(train_data, batch_size=8, shuffle=True, drop_last= True)
val_data_batched = DataLoader(val_data, batch_size=8, shuffle=True, drop_last= True)

In [71]:
batch_count = 0
for batch in tqdm(train_data_batched):
    if batch_count == 0:
        labels = batch[0]
        inputs = batch[1]
        for key, value in inputs.items():
            inputs[key] = value.squeeze(1)
        batch_count += 1
    else:
        break

  8%|▊         | 1/12 [00:00<00:00, 33.33it/s]


In [79]:
# class BiLSTM(nn.Module):
#     def __init__(self,hidden_size:int = 256, num_layers:int = 1) -> None:
#         super().__init__()
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.bert = BertModel.from_pretrained('bert-base-uncased',output_hidden_states= True)
#         self.drop_bert = nn.Dropout(0.1)
#         self.bilstm = nn.LSTM(input_size=768, hidden_size= hidden_size, num_layers= num_layers, bidirectional= True, batch_first= True)
#         self.dense = nn.Sequential(nn.Linear(in_features= hidden_size*2, out_features= 32),
#                                     nn.ReLU(),
#                                     nn.Linear(in_features= 32, out_features= 6),
#                                     nn.Softmax(dim=1))
        
#         # self.apply(init_weights)
        
#     def forward(self, x):
#         out = self.bert(**x).hidden_states[-2]
#         out = self.drop_bert(out)
#         lstm_out, _ = self.bilstm(out, None)
#         lstm_out = lstm_out[:,-1,:]
#         # for i in range(out.size(-2)):
#         #     if i == 0:
#         #         l_out, (hidden,cell) = self.bilstm(out[:,i,:].unsqueeze(1))
#         #     else:
#         #         l_out, (hidden,cell) = self.bilstm(out[:,i,:].unsqueeze(1),(hidden,cell))
        
#         # l_out = l_out.squeeze(1)
#         final_output = self.dense(lstm_out)
#         return final_output

class BERTClassifier(nn.Module):
    def __init__(self, num_classes= 6):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)
        
        # self.apply(init_weights)

    def forward(self, x):
        outputs = self.bert(**x)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits

In [80]:
def validate_accuracy(dataloader, model):
    model.eval()
    acc = 0
    for batch in dataloader:
        labels = batch[0]
        inputs = batch[1]
        for key, value in inputs.items():
            inputs[key] = value.squeeze(1)
        with torch.no_grad():
            output = model(inputs)
        _, preds = torch.max(output, dim=1)

        hits = sum(preds == labels)
        acc += hits/labels.size(0)
    avg_acc = acc/len(dataloader)
    
    return avg_acc

In [77]:
bertclassifier = BERTClassifier()
# bilstm = BiLSTM(hidden_size= 128)
model_opt = torch.optim.Adam(params= bertclassifier.parameters(), lr= 5e-5)
loss_function = nn.CrossEntropyLoss()

In [78]:
batch_loss, batch_count = [], 0

for epoch in range(2):
    # val_acc = validate_accuracy(val_data_batched, bertclassifier).item()
    # print(f"Epoch: {epoch}, Val Acc: {val_acc*100:.2f}%")
    for batch in tqdm(train_data_batched):
        bilstm.train()
        batch_count += 1
        labels = batch[0]
        inputs = batch[1]
        for key, value in inputs.items():
            inputs[key] = value.squeeze(1)
        
        output = bilstm(inputs)
        loss = loss_function(output,labels)
        
        model_opt.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(bilstm.parameters(), 2)
        model_opt.step()
        
        val_acc = validate_accuracy(val_data_batched, bilstm).item()
        
        print(f"Batch: {batch_count} \t Loss: {loss.item():.5f} \t Val Acc: {val_acc*100:.2f}%")
        batch_loss.append(loss.item())
    

  8%|▊         | 1/12 [00:13<02:23, 13.01s/it]

Batch: 1 	 Loss: 1.77956 	 Val Acc: 56.25%


 17%|█▋        | 2/12 [00:25<02:09, 12.90s/it]

Batch: 2 	 Loss: 1.78419 	 Val Acc: 62.50%


 25%|██▌       | 3/12 [00:38<01:54, 12.74s/it]

Batch: 3 	 Loss: 1.77663 	 Val Acc: 50.00%


 33%|███▎      | 4/12 [00:51<01:44, 13.02s/it]

Batch: 4 	 Loss: 1.79105 	 Val Acc: 50.00%


 42%|████▏     | 5/12 [01:04<01:30, 12.90s/it]

Batch: 5 	 Loss: 1.78146 	 Val Acc: 50.00%


 50%|█████     | 6/12 [01:16<01:16, 12.68s/it]

Batch: 6 	 Loss: 1.77576 	 Val Acc: 56.25%


 58%|█████▊    | 7/12 [01:29<01:03, 12.74s/it]

Batch: 7 	 Loss: 1.78911 	 Val Acc: 43.75%


 67%|██████▋   | 8/12 [01:42<00:50, 12.72s/it]

Batch: 8 	 Loss: 1.77904 	 Val Acc: 56.25%


 75%|███████▌  | 9/12 [01:55<00:38, 12.85s/it]

Batch: 9 	 Loss: 1.77107 	 Val Acc: 50.00%


 83%|████████▎ | 10/12 [02:07<00:25, 12.72s/it]

Batch: 10 	 Loss: 1.78275 	 Val Acc: 43.75%


 92%|█████████▏| 11/12 [02:20<00:12, 12.65s/it]

Batch: 11 	 Loss: 1.78402 	 Val Acc: 31.25%


100%|██████████| 12/12 [02:33<00:00, 12.83s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

Batch: 12 	 Loss: 1.76979 	 Val Acc: 18.75%


  8%|▊         | 1/12 [00:18<03:27, 18.88s/it]

Batch: 13 	 Loss: 1.76547 	 Val Acc: 18.75%


 17%|█▋        | 2/12 [00:31<02:34, 15.43s/it]

Batch: 14 	 Loss: 1.75351 	 Val Acc: 12.50%


 25%|██▌       | 3/12 [00:44<02:08, 14.23s/it]

Batch: 15 	 Loss: 1.78575 	 Val Acc: 25.00%


 33%|███▎      | 4/12 [00:57<01:48, 13.61s/it]

Batch: 16 	 Loss: 1.76209 	 Val Acc: 12.50%


 42%|████▏     | 5/12 [01:09<01:32, 13.21s/it]

Batch: 17 	 Loss: 1.74953 	 Val Acc: 18.75%


 50%|█████     | 6/12 [01:26<01:26, 14.43s/it]

Batch: 18 	 Loss: 1.75202 	 Val Acc: 12.50%


 58%|█████▊    | 7/12 [01:40<01:10, 14.10s/it]

Batch: 19 	 Loss: 1.75216 	 Val Acc: 25.00%


 67%|██████▋   | 8/12 [01:52<00:54, 13.61s/it]

Batch: 20 	 Loss: 1.74998 	 Val Acc: 18.75%


 75%|███████▌  | 9/12 [02:05<00:40, 13.37s/it]

Batch: 21 	 Loss: 1.77061 	 Val Acc: 18.75%


 83%|████████▎ | 10/12 [02:19<00:27, 13.54s/it]

Batch: 22 	 Loss: 1.80556 	 Val Acc: 18.75%


 92%|█████████▏| 11/12 [02:31<00:13, 13.24s/it]

Batch: 23 	 Loss: 1.77968 	 Val Acc: 12.50%


100%|██████████| 12/12 [02:44<00:00, 13.72s/it]

Batch: 24 	 Loss: 1.76044 	 Val Acc: 18.75%





In [25]:
def save_model(model, filepath):
    """
    Save PyTorch model parameters to a file.

    Args:
    - model (torch.nn.Module): PyTorch model to save.
    - filepath (str): Filepath to save the model parameters.
    """
    torch.save(model.state_dict(), filepath)
    print(f"Model parameters saved to '{filepath}'")

def load_model(model, filepath):
    """
    Load PyTorch model parameters from a file.

    Args:
    - model (torch.nn.Module): PyTorch model to load parameters into.
    - filepath (str): Filepath to the saved model parameters.
    """
    model.load_state_dict(torch.load(filepath))
    print(f"Model parameters loaded from '{filepath}'")


In [None]:
save_model(bertclassifier, "bert_model2.pth")

Model parameters saved to 'bert_model.pth'
