## Load SST dataset

In [None]:
import nltk
from nltk.tree import Tree
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm

from transformers import AutoTokenizer

# Shared WordPiece tokenizer used to encode both the SST splits and the wiki
# corpus; it must match the checkpoint passed to BertModel.from_pretrained below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def read_ptb_tree(tree_string):
    """Parse one bracketed (PTB-style) tree string into an ``nltk`` ``Tree``."""
    parsed_tree = Tree.fromstring(tree_string)
    return parsed_tree

def extract_sentence_and_label(tree):
    """Return ``(sentence, label)`` for one parsed tree.

    ``sentence`` is the tree's leaf tokens joined with single spaces. ``label``
    is the root node's label exactly as ``tree.label()`` returns it. For SST
    trees this is typically a string such as ``'3'``; callers convert it with
    ``int()``.
    """
    root_label = tree.label()
    tokens = tree.leaves()
    return " ".join(tokens), root_label

def read_file(file_path):
    """Parse a UTF-8 file that holds one bracketed SST tree per line.

    Args:
        file_path: path to the PTB-format split file.

    Returns:
        A list of ``{'sentence': str, 'label': <root label>}`` dicts, one per
        non-blank line, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            tree_string = line.strip()
            # Skip blank or whitespace-only lines (e.g. an extra empty line at
            # the end of the file). An empty string is not a valid tree, and the
            # parser would raise on it.
            if not tree_string:
                continue
            tree = read_ptb_tree(tree_string)
            sentence, label = extract_sentence_and_label(tree)
            data.append({'sentence': sentence, 'label': label})
    return data

In [None]:
# Kaggle input locations of the SST splits (one bracketed tree per line).
train_path = '/kaggle/input/treeset/train.txt'
test_path = '/kaggle/input/treeset/test.txt'
dev_path = '/kaggle/input/treeset/dev.txt'

# Parsed copies of each split as lists of {'sentence', 'label'} dicts.
# NOTE(review): these lists are not referenced elsewhere in the visible code
# (SST5_Dataset re-reads the files itself); presumably kept for inspection.
train_data = read_file(train_path)
test_data = read_file(test_path)
dev_data = read_file(dev_path)

In [None]:
import re
import unicodedata

# PTB escape tokens for brackets (-LRB-, -RRB-, -LSB-, -RSB-, -LCB-, -RCB-).
# If left in place, the punctuation regex below turns them into spurious words
# such as "lrb".
_PTB_BRACKET_RE = re.compile(r'-[LR][RSC]B-')
# Runs of digits, non-word characters and underscores.
_NON_LETTER_RE = re.compile(r'[\d\W_]+')


def canonicalize_text(text):
    """Normalize a sentence to lowercase, accent-free words separated by spaces.

    Steps:
      1. Replace PTB bracket escape tokens (``-LRB-`` etc.) with spaces.
      2. Decompose to NFD and drop combining marks (category ``Mn``), so
         accented letters become base letters. This works whether the input
         is precomposed or already decomposed.
      3. Collapse every run of digits, non-word characters and underscores
         into a single space.
      4. Lowercase and trim surrounding whitespace.

    Args:
        text: input string.

    Returns:
        The canonicalized string (possibly empty).
    """
    text = _PTB_BRACKET_RE.sub(' ', text)

    # Strip accents *before* the punctuation regex. A standalone combining
    # mark is matched by ``\W``, so running the regex first would split
    # decomposed words ("nai\u0308ve" -> "nai ve").
    text = ''.join(
        c for c in unicodedata.normalize('NFD', text)
        if unicodedata.category(c) != 'Mn'
    )

    text = _NON_LETTER_RE.sub(' ', text)

    return text.lower().strip()

In [None]:
class SST5_Dataset(Dataset):
    """SST-5 sentences tokenized for BERT, paired with integer labels.

    Each item is ``(input_ids, attention_mask, label)``. The first two are 1-D
    tensors of length ``max_length``; ``label`` is a 0-dim ``torch.long`` tensor.
    Sentences are canonicalized and tokenized eagerly in ``__init__``.

    Args:
        file_path: path to a PTB-tree split file readable by ``read_file``.
        max_length: sequence length to pad/truncate to. The default of 512
            preserves the original behavior. SST sentences are presumably much
            shorter, so a smaller value (e.g. 128) should cut memory and encoder
            compute considerably; verify on the data first.
    """

    def __init__(self, file_path, max_length=512):
        self.max_length = max_length
        self.data = []
        for row in read_file(file_path):
            encoding = tokenizer(
                canonicalize_text(row['sentence']),
                add_special_tokens=True,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors="pt",
            )
            label = row['label']
            # Tree labels arrive as strings (e.g. '3'); normalize to int.
            if isinstance(label, str):
                label = int(label)
            self.data.append((encoding, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        X, y = self.data[idx]
        # return_tensors="pt" adds a leading batch axis of size 1; drop it so
        # the DataLoader can stack items itself.
        input_ids = X['input_ids'].squeeze(0)
        attention_mask = X['attention_mask'].squeeze(0)
        label = torch.tensor(y, dtype=torch.long)
        return input_ids, attention_mask, label

In [None]:
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import DataLoader, ConcatDataset

# Tokenized SST-5 splits (train / test / dev), built in that order.
trainset, testset, valset = (
    SST5_Dataset(train_path),
    SST5_Dataset(test_path),
    SST5_Dataset(dev_path),
)

In [None]:
# from torch.utils.data import Subset
# import random

# train_indices = random.sample(range(len(trainset)), 500)
# test_indices = random.sample(range(len(testset)), 100)
# val_indices = random.sample(range(len(valset)), 100)

# trainset = Subset(trainset, train_indices)
# testset = Subset(testset, test_indices)
# valset = Subset(valset, val_indices)

In [None]:
def read_unsupervised_sentences(file_path):
    """Return each line of a UTF-8 file with surrounding whitespace stripped.

    Blank lines are kept as empty strings, so there is one entry per line.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

In [None]:
# Plain-text corpus for unsupervised SimCSE training; WikiDataset treats each line as one sentence.
wiki_path = '/kaggle/input/sampled-wiki/processed_sentences.txt'

In [None]:
class WikiDataset(Dataset):
    """Unlabeled sentences, one per line, tokenized for BERT.

    Items use the same ``(input_ids, attention_mask, label)`` layout as
    ``SST5_Dataset`` so both can be combined in a ``ConcatDataset``. ``label``
    is a dummy ``-1`` because these sentences have no annotation.

    Args:
        file_path: UTF-8 text file with one sentence per line. Blank and
            whitespace-only lines are skipped.
        tokenizer: tokenizer callable (e.g. a Hugging Face tokenizer). It must
            accept the keyword arguments used below and return a mapping with
            ``input_ids`` and ``attention_mask`` tensors of shape
            ``(1, max_length)``.
        max_length: length every sequence is padded/truncated to.
    """

    def __init__(self, file_path, tokenizer, max_length=512):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                sentence = line.strip()
                # An empty line tokenizes to special tokens only. All such items
                # are identical inputs, which become false negatives for the
                # contrastive loss whenever two land in the same batch.
                if not sentence:
                    continue
                tokenized = tokenizer(
                    sentence,
                    add_special_tokens=True,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt"
                )
                self.data.append(tokenized)

        # self.data = self.data[:100]   ##sampling

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokenized = self.data[idx]
        # Drop the size-1 batch axis added by return_tensors="pt".
        input_ids = tokenized['input_ids'].squeeze(0)
        attention_mask = tokenized['attention_mask'].squeeze(0)
        dummy_label = torch.tensor(-1, dtype=torch.long)
        return input_ids, attention_mask, dummy_label

In [None]:
# Tokenize the wiki corpus with the shared tokenizer (default max_length=512).
wiki_dataset = WikiDataset(wiki_path, tokenizer)

In [None]:
# SimCSE corpus: wiki sentences plus the SST training sentences. train_uncl
# only reads (input_ids, attention_mask), so the SST labels are ignored here.
unsup_trainset = ConcatDataset([wiki_dataset, trainset])

In [None]:
from transformers import BertModel

import logging
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score

In [None]:
class UnSup_BERT(nn.Module):
    """Sentence encoder that returns BERT's ``pooler_output`` vector.

    Args:
        bert: a BERT-style model whose output supports ``['pooler_output']``
            (e.g. ``transformers.BertModel``).
        is_unsup_train: if True, pass the pooled vector through an extra
            ``nn.Dropout(0.3)``, which only has an effect in training mode.
            If False, return the pooled vector unchanged.
    """

    def __init__(self, bert, is_unsup_train=True):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.3)
        self.is_unsup_train = is_unsup_train

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
        )
        sentence_vec = encoded['pooler_output']
        return self.dropout(sentence_vec) if self.is_unsup_train else sentence_vec

In [None]:
# Pre-trained encoder. UnSup_BERT stores this same object (not a copy), so
# SimCSE training below updates `bert` in place.
bert = BertModel.from_pretrained('bert-base-uncased')
uncl_model = UnSup_BERT(bert, is_unsup_train=True)

# Global device read by the training and evaluation cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
uncl_model.to(device)

In [None]:
 # Train with Unsupervised SimCSE
def train_uncl(model, criterion, trainset, batch_size, epochs,
               path='/kaggle/working/best_model_uncl.pth',
               temperature=0.05, lr=1e-5, device=None):
    """Train ``model`` with the unsupervised SimCSE objective, then save it.

    Each batch is encoded twice in training mode. Dropout makes the two passes
    differ, so the two views of the same sentence form a positive pair, and
    every other sentence in the batch is a negative. Logits are pairwise
    cosine similarities divided by ``temperature``.

    Args:
        model: encoder mapping ``(input_ids, attention_mask)`` to a
            ``(batch, dim)`` tensor.
        criterion: loss applied to the similarity logits with targets
            ``0..batch-1`` (e.g. ``nn.CrossEntropyLoss()``). If None,
            ``F.cross_entropy`` is used.
        trainset: dataset whose items start with ``(input_ids, attention_mask)``;
            any further fields (e.g. labels) are ignored.
        batch_size: sentences per batch; each sentence gets
            ``batch_size - 1`` in-batch negatives.
        epochs: number of passes over ``trainset``.
        path: where ``model.state_dict()`` is written after the last epoch.
        temperature: divisor applied to the cosine similarities.
        lr: Adam learning rate.
        device: device to move inputs to; defaults to the device of the
            model's parameters.

    Returns:
        The trained ``model`` (updated in place).
    """
    if device is None:
        # Inputs must live where the weights live.
        device = next(model.parameters()).device
    loss_fn = criterion if criterion is not None else F.cross_entropy

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # Dropout must be active: it is what makes the two passes differ.
        model.train()

        train_loss = 0.0
        for batch in tqdm(train_loader, desc="Training SimCSE ...:"):
            b_ids, b_mask, *_ = batch
            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            optimizer.zero_grad()

            emb1 = model(b_ids, b_mask)
            emb2 = model(b_ids, b_mask)

            # sim[i, j] = cos(emb1[i], emb2[j]) / T; the positives sit on the diagonal.
            sim_matrix = F.cosine_similarity(emb1.unsqueeze(1), emb2.unsqueeze(0), dim=-1)
            sim_matrix = sim_matrix / temperature
            labels_cl = torch.arange(b_ids.size(0), device=sim_matrix.device)

            loss = loss_fn(sim_matrix, labels_cl)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Guard against an empty dataset instead of dividing by zero.
        train_loss /= max(len(train_loader), 1)
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}")

    torch.save(model.state_dict(), path)
    return model

# Run SimCSE training. This updates uncl_model (and the shared `bert`) in place
# and writes the weights to train_uncl's default path.
# NOTE(review): batch_size=4 leaves only 3 in-batch negatives per sentence.
train_uncl(uncl_model, nn.CrossEntropyLoss(), unsup_trainset, batch_size=4, epochs=3)


In [None]:
# Re-wrap the SimCSE-trained encoder for downstream use. With
# is_unsup_train=False, forward() returns the pooled vector without the extra dropout.
# NOTE(review): `bert` is the same object trained above, so this load re-applies
# weights it already holds. No map_location is given, so the checkpoint is
# restored to the device it was saved from.
uncl_model = UnSup_BERT(bert, is_unsup_train=False)
uncl_model.load_state_dict(torch.load('/kaggle/working/best_model_uncl.pth'))

## Define BERT classifier

In [None]:
class BertClassifier(nn.Module):
    """MLP classification head on top of a sentence encoder.

    Args:
        num_labels: number of output classes.
        encoder: module mapping ``(input_ids, attention_mask=...)`` to a
            ``(batch, hidden_size)`` tensor. Defaults to the module-level
            ``uncl_model`` (the SimCSE-trained encoder). The default is shared,
            not copied: every classifier built without ``encoder`` fine-tunes,
            and loads weights into, the same encoder object.
        hidden_size: width of the encoder output (768 for BERT-base).

    Note:
        The head contains ``BatchNorm1d``, which raises in training mode on a
        batch with a single sample.
    """

    def __init__(self, num_labels, encoder=None, hidden_size=768):
        super().__init__()
        # Attribute name kept as `bert` so saved state_dict keys stay compatible.
        self.bert = uncl_model if encoder is None else encoder
        # To freeze the encoder instead: self.bert.requires_grad_(False)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape ``(batch, num_labels)``."""
        sentence_vec = self.bert(input_ids, attention_mask=attention_mask)
        return self.classifier(sentence_vec)

## Define classification training and evaluation loops

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Each batch must be ``(input_ids, attention_mask, labels)``. The mean loss
    is also printed.
    """
    model.train()
    running_loss = 0.0

    for batch in tqdm(dataloader, desc="Training"):
        input_ids, attention_mask, labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        batch_loss = criterion(model(input_ids, attention_mask), labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    mean_loss = running_loss / len(dataloader)
    print(f"Train Loss: {mean_loss:.4f}")
    return mean_loss

def eval_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(mean_batch_loss, accuracy)``. Accuracy is ``accuracy_score`` over
        argmax predictions. Both values are also printed.
    """
    model.eval()
    total_loss = 0.0
    predictions, targets = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids, attention_mask, labels = (t.to(device) for t in batch)

            logits = model(input_ids, attention_mask)
            total_loss += criterion(logits, labels).item()

            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    mean_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(targets, predictions)
    print(f"Eval Loss: {mean_loss:.4f}, Accuracy: {accuracy:.4f}")
    return mean_loss, accuracy

In [None]:
def train_cls(model, criterion, trainset, valset, epochs, save_dir='/kaggle/working/',
              batch_size=32, lr=1e-5, device=None):
    """Fine-tune a classifier and keep the epoch with the lowest validation loss.

    Args:
        model: classifier mapping ``(input_ids, attention_mask)`` to logits.
        criterion: loss applied to ``(logits, labels)``.
        trainset: training dataset of ``(input_ids, attention_mask, label)``.
        valset: validation dataset with the same layout.
        epochs: number of training epochs.
        save_dir: directory for ``best_model_cls.pth``. A trailing slash is
            optional.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        device: device for inputs; defaults to the device of the model's
            parameters.

    Returns:
        ``model`` with the weights of its best validation epoch restored (the
        same weights saved to ``best_model_cls.pth``). If no checkpoint was
        written (e.g. ``epochs == 0``), ``model`` is returned unchanged.
    """
    import os

    if device is None:
        device = next(model.parameters()).device

    # BatchNorm layers (as in BertClassifier's head) raise on a 1-sample batch
    # in training mode. Drop the final batch only when it would be that small.
    drop_last = len(trainset) % batch_size == 1
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_model_path = os.path.join(save_dir, "best_model_cls.pth")
    best_val_loss = float('inf')
    saved_any = False

    for epoch in range(epochs):
        train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, _ = eval_one_epoch(model, val_loader, criterion, device)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            saved_any = True
            print(f"Save model at epoch {epoch + 1}")

    # Return the checkpointed (best) weights, not whatever the last epoch produced.
    if saved_any:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    return model

# Build the 5-class head on top of the SimCSE-trained encoder (BertClassifier's
# default encoder is uncl_model) and fine-tune for 6 epochs. The best
# validation-loss checkpoint is written to /kaggle/working/best_model_cls.pth.
model = BertClassifier(5)
model.to(device)

cls_model = train_cls(model, nn.CrossEntropyLoss(), trainset, valset, 6)

## Test and visualize results

In [None]:
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

def _collect_predictions(model, dataset, device, batch_size):
    """Run ``model`` over ``dataset``; return (mean batch loss, true labels, predicted labels)."""
    loader = DataLoader(dataset, batch_size=batch_size)
    total_loss = 0.0
    all_labels, all_preds = [], []

    with torch.no_grad():
        for input_ids, attention_mask, labels in tqdm(loader, desc="Testing"):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            logits = model(input_ids, attention_mask=attention_mask)
            total_loss += F.cross_entropy(logits, labels).item()

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return total_loss / len(loader), all_labels, all_preds


def _plot_confusion_matrix(cm, class_ids):
    """Draw ``cm`` as an annotated heatmap with ``class_ids`` on both axes."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_ids,
        yticklabels=class_ids,
    )
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.show()


def test_model(model_path, testset, device, batch_size=32, num_labels=5):
    """Evaluate a saved ``BertClassifier`` checkpoint on ``testset`` and plot its confusion matrix.

    Args:
        model_path: path to a ``state_dict`` saved from a ``BertClassifier``.
        testset: dataset yielding ``(input_ids, attention_mask, label)``.
        device: device to run inference on.
        batch_size: evaluation batch size.
        num_labels: number of classes the checkpoint was trained with.

    Returns:
        A dict with these keys:
          - ``loss``: mean per-batch cross-entropy.
          - ``accuracy``.
          - ``precision``, ``recall``, ``f1``: weighted averages.
          - ``confusion_matrix``: rows are true classes and columns are
            predicted classes, both ordered as the sorted union of labels
            seen in targets and predictions.

    Raises:
        ValueError: if ``testset`` is empty.
    """
    if len(testset) == 0:
        raise ValueError("testset is empty")

    model = BertClassifier(num_labels=num_labels)
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    avg_loss, all_labels, all_preds = _collect_predictions(model, testset, device, batch_size)

    correct_predictions = int(np.sum(np.asarray(all_preds) == np.asarray(all_labels)))
    accuracy = correct_predictions / len(testset)
    # zero_division=0 yields the same values as sklearn's default ("warn")
    # but without warnings for classes that are never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels,
        all_preds,
        average='weighted',
        zero_division=0,
    )

    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # Use the same class ordering for the matrix and its tick labels. Labels
    # taken from the targets alone mismatch the matrix whenever the model
    # predicts a class that never occurs in the test labels.
    class_ids = np.union1d(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    _plot_confusion_matrix(cm, class_ids)

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
    }

# Evaluate the best-validation checkpoint written by train_cls (default save_dir).
model_path = '/kaggle/working/best_model_cls.pth'

# NOTE(review): recomputes the device rather than reusing the global `device`;
# both expressions pick CUDA when available.
metrics = test_model(
    model_path=model_path,
    testset=testset,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)