# In [None]:
import os
import random
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, LongformerModel, LongformerConfig, LongformerTokenizer

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports the seed as ``PYTHONHASHSEED``, and configures cuDNN to select
    deterministic kernels.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here only affects processes started afterwards, not the current one.

    :param seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels are not sufficient on their own: cuDNN's benchmark
    # mode autotunes per run and may pick a different algorithm each time.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class TextDataset(Dataset):
    """Paired-text dataset for the Siamese classifier.

    Each item tokenizes the ``source_content`` and ``suspicious_content``
    columns of one row of ``texts`` independently and returns both encodings
    together with that row's label.

    :param texts: DataFrame with ``source_content`` and ``suspicious_content``
        columns; rows are addressed positionally (``iloc``).
    :param labels: Integer class labels aligned positionally with ``texts``
        (NumPy array, list, or pandas Series).
    :param tokenizer: Hugging Face-style tokenizer exposing ``encode_plus``.
    :param max_length: Every encoding is truncated/padded to this length.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def _encode(self, text):
        """Tokenize one text into fixed-length 1-D ``(input_ids, attention_mask)`` tensors."""
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Remove only the leading batch dimension of size 1. A bare squeeze()
        # would also collapse the sequence axis when max_length == 1.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def _label(self, idx):
        """Return the label at position ``idx``."""
        # On a pandas Series, [] indexes by label, which raises KeyError or
        # returns the wrong row once train_test_split has shuffled the index.
        if hasattr(self.labels, 'iloc'):
            return self.labels.iloc[idx]
        return self.labels[idx]

    def __getitem__(self, idx):
        row = self.texts.iloc[idx]  # fetch the row once for both columns
        input_ids_1, attention_mask_1 = self._encode(row['source_content'])
        input_ids_2, attention_mask_2 = self._encode(row['suspicious_content'])
        return {
            'input_ids_1': input_ids_1,
            'attention_mask_1': attention_mask_1,
            'input_ids_2': input_ids_2,
            'attention_mask_2': attention_mask_2,
            'labels': torch.tensor(self._label(idx), dtype=torch.long)
        }


class SiameseLongformer(nn.Module):
    """Siamese Longformer classifier over a pair of documents.

    Both documents pass through the *same* Longformer encoder (shared
    weights). Each one is mean-pooled over its non-padding tokens, and the
    two pooled vectors are concatenated and fed to a small MLP head.

    :param model_name: Hugging Face model id or path of a Longformer checkpoint.
    :param num_labels: Number of output classes (width of the logits).
    :param max_length: Accepted for API compatibility; unused here because the
        sequence length is determined by the tokenized inputs.
    """

    def __init__(self, model_name, num_labels, max_length):
        super().__init__()
        self.config = LongformerConfig.from_pretrained(model_name)
        self.config.output_hidden_states = False

        # Single encoder for both inputs (shared weights).
        self.longformer = LongformerModel.from_pretrained(model_name, config=self.config)
        # Trade compute for memory on long inputs. Passing
        # `gradient_checkpointing=True` to from_pretrained is deprecated (and
        # with an explicit config may be forwarded to the model constructor);
        # this method is the supported API.
        self.longformer.gradient_checkpointing_enable()

        # The head sees both pooled embeddings concatenated, hence 2 * hidden_size.
        self.dropout = nn.Dropout(0.2)
        self.dense = nn.Linear(self.config.hidden_size * 2, 64)
        self.classifier = nn.Linear(64, num_labels)
        self.relu = nn.ReLU()

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Average ``hidden_states`` over the positions where ``attention_mask`` is non-zero.

        Expects hidden_states shaped (batch, seq, hidden) and attention_mask
        shaped (batch, seq). Returns (batch, hidden). A row with no unmasked
        tokens yields zeros instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def encode(self, input_ids, attention_mask):
        """Encode a batch of token ids into one pooled embedding per document."""
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Pool only over real tokens. A plain mean over the padded length would
        # be dominated by padding positions for short documents.
        return self._masked_mean(outputs.last_hidden_state, attention_mask)

    def forward(self, input_ids_1, attention_mask_1,
                input_ids_2, attention_mask_2):
        """Return class logits for each (document 1, document 2) pair in the batch."""
        embedding1 = self.encode(input_ids_1, attention_mask_1)
        embedding2 = self.encode(input_ids_2, attention_mask_2)

        concatenated = torch.cat((embedding1, embedding2), dim=1)

        x = self.dense(concatenated)
        x = self.relu(x)
        x = self.dropout(x)
        return self.classifier(x)
    
def _move_batch(batch, device):
    """Return ``(model_inputs, labels)`` from a collated batch, moved to ``device``."""
    inputs = tuple(
        batch[key].to(device)
        for key in ('input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2')
    )
    return inputs, batch['labels'].to(device)


def _print_report(title, labels, preds, map_label):
    """Print a text classification report with label ids translated through ``map_label``."""
    report = classification_report(
        [map_label[i] for i in labels],
        [map_label[i] for i in preds],
        zero_division=0,
    )
    print(f"\n{title}:")
    print(report)


def train_model(model, train_loader, val_loader, criterion, optimizer, schedular, num_epochs, device,
                map_label=None, early_stopping_patience=3, checkpoint_path='best_model.pt'):
    """Train ``model`` with per-epoch validation, best checkpointing and early stopping.

    Each epoch makes one pass over ``train_loader`` (one optimizer step per
    batch), steps ``schedular`` once, then evaluates on ``val_loader``.
    Whenever the mean validation loss improves, ``model.state_dict()`` is
    saved to ``checkpoint_path``. Training stops early after
    ``early_stopping_patience`` consecutive epochs without improvement. On
    exit the best checkpoint is loaded back into ``model``.

    :param map_label: Optional mapping from label id to class name. When
        given, train and validation classification reports are printed
        every epoch.
    :param checkpoint_path: File the best weights are written to
        (defaults to ``'best_model.pt'``).
    :returns: ``(val_preds, val_labels)`` lists from the best epoch, i.e. the
        predictions that match the restored weights. If no epoch ever
        improved (e.g. the loss was NaN), the last epoch's lists are returned
        and the weights are left unchanged.
    """
    best_val_loss = float('inf')
    patience_counter = 0  # consecutive epochs without validation improvement
    best_val_preds, best_val_labels = None, None
    all_val_preds, all_val_labels = [], []

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        train_loss = 0.0
        all_train_preds, all_train_labels = [], []
        for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
            inputs, labels = _move_batch(batch, device)

            optimizer.zero_grad()
            outputs = model(*inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # argmax of the logits equals argmax of their softmax.
            all_train_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_train_labels.extend(labels.cpu().numpy())

        schedular.step()

        # ---- validation ----
        model.eval()
        val_loss = 0.0
        all_val_preds, all_val_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                inputs, labels = _move_batch(batch, device)
                outputs = model(*inputs)
                val_loss += criterion(outputs, labels).item()
                all_val_preds.extend(outputs.argmax(dim=1).cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f'\nEpoch {epoch + 1}:')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')

        # Report before the early-stopping check so the final epoch is reported too.
        if map_label:
            _print_report("Training Metrics", all_train_labels, all_train_preds, map_label)
            _print_report("Validation Metrics", all_val_labels, all_val_preds, map_label)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            best_val_preds, best_val_labels = all_val_preds, all_val_labels
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience counter: {patience_counter}/{early_stopping_patience}")

        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

    if best_val_preds is None:
        return all_val_preds, all_val_labels

    # Hand back the best model rather than the last (possibly overfit) one,
    # so the returned predictions and the model's weights agree.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return best_val_preds, best_val_labels

    

def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a row-normalized confusion matrix on the current matplotlib figure.

    Each row is divided by its total, so a cell shows the fraction of that
    row's true samples predicted as the column's class. Rows whose total is
    zero (class absent from the true labels) are drawn as 0 rather than NaN.
    The caller's ``cm`` is not modified.

    :param cm: Square count matrix; rows are true labels, columns predictions.
    :param classes: Tick labels, one per row/column, in matrix order.
    :param title: Plot title.
    :param cmap: Matplotlib colormap for the cells.
    """
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1, keepdims=True)
    # A plain division would produce NaN (and RuntimeWarnings) for empty rows.
    cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=25)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90, fontsize=15)
    plt.yticks(tick_marks, classes, fontsize=15)

    fmt = '.2f'
    # Use white text on cells darker than half the max so annotations stay readable.
    thresh = cm.max() / 2. if cm.size else 0.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black", fontsize=14)

    plt.ylabel('True label', fontsize=20)
    plt.xlabel('Predicted label', fontsize=20)

class DiceLoss(nn.Module):
    """Soft Dice loss for single-label multi-class classification.

    For each class c, computed over the whole batch::

        dice_c = (2 * sum(p_c * y_c) + eps) / (sum(p_c + y_c) + eps)

    where ``p`` are softmax probabilities and ``y`` the one-hot targets. The
    loss is the mean over classes of ``(1 - dice_c)``. When class weights are
    given, each term is multiplied by its weight before averaging.
    """

    def __init__(self, class_weights=None, epsilon=1e-7):
        """
        :param class_weights: Optional per-class weights (tensor, NumPy array
            or sequence), one entry per class.
        :param epsilon: Smoothing term added to numerator and denominator to
            avoid division by zero.
        """
        super().__init__()
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        # Non-persistent buffer: moves with module.to(device) but is not
        # added to state_dict().
        self.register_buffer('class_weights', class_weights, persistent=False)
        self.epsilon = epsilon

    def forward(self, logits, targets):
        """Return the scalar Dice loss for (batch, num_classes) logits and integer targets."""
        num_classes = logits.size(1)
        probs = nn.functional.softmax(logits, dim=1)
        targets_one_hot = nn.functional.one_hot(targets, num_classes=num_classes).to(probs.dtype)

        # Per-class overlap and total mass, summed over the batch dimension.
        intersection = torch.sum(probs * targets_one_hot, dim=0)
        cardinality = torch.sum(probs + targets_one_hot, dim=0)

        dice = (2.0 * intersection + self.epsilon) / (cardinality + self.epsilon)
        dice_loss = 1.0 - dice

        if self.class_weights is not None:
            # Align with the loss in case the module was not moved to the logits' device.
            weights = self.class_weights.to(device=dice_loss.device, dtype=dice_loss.dtype)
            dice_loss = dice_loss * weights

        return dice_loss.mean()


def main():
    """Run the full pipeline: load ``Data.csv``, train the Siamese Longformer, plot the results.

    Reads ``Data.csv`` (columns ``source_content``, ``suspicious_content``,
    ``plagiarism_type``), splits it 70/30, trains with Dice loss, and shows a
    confusion matrix of the validation predictions returned by ``train_model``.
    """
    MAX_SEQUENCE_LENGTH = 4096
    MODEL_NAME = "allenai/longformer-base-4096"
    BATCH_SIZE = 14
    NUM_EPOCHS = 20
    LEARNING_RATE = 2e-5
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    TEXT_COLUMNS = ['source_content', 'suspicious_content']

    # Set seed for reproducibility
    set_seed(33)

    # Load and preprocess data
    df = pd.read_csv('Data.csv')
    # Unlabeled rows would factorize to code -1, which one_hot() and the
    # label map cannot handle.
    df = df.dropna(subset=['plagiarism_type']).reset_index(drop=True)
    # Empty cells are read as float NaN, which the tokenizer cannot encode.
    df[TEXT_COLUMNS] = df[TEXT_COLUMNS].fillna('').astype(str)

    # Factorize once and keep the uniques. Factorizing the already-encoded
    # column a second time would return the integer codes themselves, so the
    # reports and plot would show 0, 1, 2, ... instead of class names.
    codes, uniques = df['plagiarism_type'].factorize()
    df['plagiarism_type'] = codes
    map_label = dict(enumerate(uniques))

    X_train, X_test, y_train, y_test = train_test_split(
        df[TEXT_COLUMNS],
        df['plagiarism_type'].values,
        random_state=33,
        test_size=0.3
    )

    # Initialize tokenizer
    tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

    # Create datasets
    train_dataset = TextDataset(X_train, y_train, tokenizer, MAX_SEQUENCE_LENGTH)
    test_dataset = TextDataset(X_test, y_test, tokenizer, MAX_SEQUENCE_LENGTH)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Initialize model
    model = SiameseLongformer(MODEL_NAME, len(map_label), MAX_SEQUENCE_LENGTH)
    model = model.to(DEVICE)

    # Initialize loss and optimizer
    from sklearn.utils.class_weight import compute_class_weight

    # "balanced" weights are inversely proportional to class frequency in y_train.
    class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
    criterion = DiceLoss(class_weights=class_weights)
    # Alternative: nn.CrossEntropyLoss(weight=class_weights), or unweighted nn.CrossEntropyLoss().

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    # Multiply the learning rate by 0.7 every 3 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.7)

    # Train model
    all_preds, all_labels = train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, NUM_EPOCHS, DEVICE, map_label)

    # Pin the label order so the matrix has exactly one row/column per class,
    # matching the tick labels even if a class never appears in the results.
    cnf_matrix = confusion_matrix(all_labels, all_preds, labels=list(map_label))
    plt.figure(figsize=(7, 7))
    plot_confusion_matrix(cnf_matrix, classes=list(map_label.values()))
    plt.show()

# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()