In [None]:
import os

# `!export VAR=...` runs in a throwaway subshell, so the variable never reached
# the kernel process. Set it in-process instead, before torch is imported.
# NOTE: this is a debugging aid (CUDA kernels launch synchronously, so errors
# are reported at the failing call) and it slows GPU work; remove it for
# timing-sensitive runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import time, re, pickle
from collections import Counter
import nltk
from nltk.corpus import brown
from nltk.tokenize import TreebankWordTokenizer, casual_tokenize
from tqdm import tqdm
import pandas as pd
nltk.download('brown')
nltk.download('punkt')
import sys
import random

[nltk_data] Downloading package brown to /usr/share/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
##############################

##############################
class ELMoBiLM(nn.Module):
    """
    Bidirectional language model that exposes its layers without mixing them.

    ``forward`` returns vocabulary logits for each direction plus the tuple
    ``(e0, e1, e2)``:
      - e0: input embeddings
      - e1: first half of the second LSTM's output features (forward direction)
      - e2: second half of the second LSTM's output features (backward direction)
    Combining e0/e1/e2 is left to the downstream classification step.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2, pretrained_embeddings=None):
        super().__init__()
        # Index 0 is the padding id.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)

        # Both LSTM stacks share every setting except their input width.
        shared_lstm_kwargs = dict(num_layers=num_layers, batch_first=True, bidirectional=True)
        self.lstm1 = nn.LSTM(embed_dim, hidden_dim, **shared_lstm_kwargs)
        # lstm1 is bidirectional, so its output is 2 * hidden_dim wide.
        self.lstm2 = nn.LSTM(hidden_dim * 2, hidden_dim, **shared_lstm_kwargs)

        # One vocabulary projection per direction.
        self.forward_linear = nn.Linear(hidden_dim, vocab_size)
        self.backward_linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        """Return ``(forward_logits, backward_logits, (emb, hf, hb))`` for a batch of token ids."""
        emb = self.embedding(input_ids)
        first_pass, _ = self.lstm1(emb)
        second_pass, _ = self.lstm2(first_pass)

        # Split the concatenated bidirectional features into their two halves.
        half = second_pass.size(2) // 2
        hf = second_pass[:, :, :half]
        hb = second_pass[:, :, half:]

        return self.forward_linear(hf), self.backward_linear(hb), (emb, hf, hb)


In [3]:
class Tokenizer:
    """Rule-based sentence splitter and word tokenizer.

    Text is first normalised (URLs, hashtags, mentions, percentages, ages,
    times and durations become placeholder words), then cut into sentences.
    Each sentence is tokenised with NLTK's casual tokenizer followed by the
    Treebank tokenizer and wrapped in START/END markers.
    """

    def __init__(self):
        self.treebank_tokenizer = TreebankWordTokenizer()

    def preprocess_special_cases(self, text):
        """Replace special surface forms with placeholder words and return the new text."""
        # (pattern, placeholder, flags), applied top to bottom. The order is
        # significant: e.g. AGE must consume "30 yrs" before TIMEPERIOD runs.
        substitutions = (
            (r'https?://\S+|www\.\S+', 'URL', 0),
            (r'#\w+', 'HASHTAG', 0),
            (r'@\w+', 'MENTION', 0),
            (r'\b\d+%|\b\d+\s?percent\b', 'PERCENTAGE', re.IGNORECASE),
            (r'\b\d+\s?(years old|yo|yrs|yr)\b', 'AGE', re.IGNORECASE),
            (r'\b\d{1,2}:\d{2}\s?(AM|PM|am|pm)?\b', 'TIME', 0),
            (r'\b\d+\s?(hours|hrs|minutes|mins|seconds|secs|days|weeks|months|years)\b', 'TIMEPERIOD', re.IGNORECASE),
        )
        for pattern, placeholder, flags in substitutions:
            text = re.sub(pattern, placeholder, text, flags=flags)
        return text

    def custom_sentence_split(self, text):
        """Split on whitespace that follows '.', '!' or '?', shielding known abbreviations."""
        protected = (
            "Mr.", "Dr.", "Ms.", "Mrs.", "Prof.", "Sr.", "Jr.", "Ph.D.", "M.D.",
            "B.A.", "M.A.", "D.D.S.", "D.V.M.", "LL.D.", "B.C.", "a.m.", "p.m.",
            "etc.", "e.g.", "i.e.", "vs.", "Jan.", "Feb.", "Mar.", "Apr.", "Jun.",
            "Jul.", "Aug.", "Sep.", "Oct.", "Nov.", "Dec.",
        )
        shielded = text
        for abbreviation in protected:
            # Hide the abbreviation's dots so the splitter does not see them.
            shielded = shielded.replace(abbreviation, abbreviation.replace(".", "<DOT>"))
        pieces = re.split(r'(?<=[.!?])\s+', shielded.strip())
        return [piece.replace("<DOT>", ".") for piece in pieces]

    def preprocess(self, text):
        """Normalise, sentence-split and tokenise ``text``; returns one token list per sentence."""
        normalised = self.preprocess_special_cases(text)
        result = []
        for sentence in self.custom_sentence_split(normalised):
            # Each casual token is further broken up by the Treebank tokenizer.
            words = [
                piece
                for coarse in casual_tokenize(sentence, preserve_case=True)
                for piece in self.treebank_tokenizer.tokenize(coarse)
            ]
            result.append(self.add_special_tokens(words))
        return result

    def add_special_tokens(self, tokens):
        """Return ``tokens`` wrapped in the START / END sentence markers."""
        return ['START', *tokens, 'END']

    def tokenize(self, text):
        """Public entry point; identical to :meth:`preprocess`."""
        return self.preprocess(text)


In [None]:

class Tokenizer:
    """Rule-based sentence splitter and word tokenizer.

    Text is first normalised (URLs, hashtags, mentions, percentages, ages,
    times and durations become placeholder words), then cut into sentences.
    Each sentence is tokenised with NLTK's casual tokenizer followed by the
    Treebank tokenizer and wrapped in START/END markers.
    """

    def __init__(self):
        self.treebank_tokenizer = TreebankWordTokenizer()

    def preprocess_special_cases(self, text):
        """Replace special surface forms with placeholder words and return the new text."""
        # (pattern, placeholder, flags), applied top to bottom. The order is
        # significant: e.g. AGE must consume "30 yrs" before TIMEPERIOD runs.
        substitutions = (
            (r'https?://\S+|www\.\S+', 'URL', 0),
            (r'#\w+', 'HASHTAG', 0),
            (r'@\w+', 'MENTION', 0),
            (r'\b\d+%|\b\d+\s?percent\b', 'PERCENTAGE', re.IGNORECASE),
            (r'\b\d+\s?(years old|yo|yrs|yr)\b', 'AGE', re.IGNORECASE),
            (r'\b\d{1,2}:\d{2}\s?(AM|PM|am|pm)?\b', 'TIME', 0),
            (r'\b\d+\s?(hours|hrs|minutes|mins|seconds|secs|days|weeks|months|years)\b', 'TIMEPERIOD', re.IGNORECASE),
        )
        for pattern, placeholder, flags in substitutions:
            text = re.sub(pattern, placeholder, text, flags=flags)
        return text

    def custom_sentence_split(self, text):
        """Split on whitespace that follows '.', '!' or '?', shielding known abbreviations."""
        protected = (
            "Mr.", "Dr.", "Ms.", "Mrs.", "Prof.", "Sr.", "Jr.", "Ph.D.", "M.D.",
            "B.A.", "M.A.", "D.D.S.", "D.V.M.", "LL.D.", "B.C.", "a.m.", "p.m.",
            "etc.", "e.g.", "i.e.", "vs.", "Jan.", "Feb.", "Mar.", "Apr.", "Jun.",
            "Jul.", "Aug.", "Sep.", "Oct.", "Nov.", "Dec.",
        )
        shielded = text
        for abbreviation in protected:
            # Hide the abbreviation's dots so the splitter does not see them.
            shielded = shielded.replace(abbreviation, abbreviation.replace(".", "<DOT>"))
        pieces = re.split(r'(?<=[.!?])\s+', shielded.strip())
        return [piece.replace("<DOT>", ".") for piece in pieces]

    def preprocess(self, text):
        """Normalise, sentence-split and tokenise ``text``; returns one token list per sentence."""
        normalised = self.preprocess_special_cases(text)
        result = []
        for sentence in self.custom_sentence_split(normalised):
            # Each casual token is further broken up by the Treebank tokenizer.
            words = [
                piece
                for coarse in casual_tokenize(sentence, preserve_case=True)
                for piece in self.treebank_tokenizer.tokenize(coarse)
            ]
            result.append(self.add_special_tokens(words))
        return result

    def add_special_tokens(self, tokens):
        """Return ``tokens`` wrapped in the START / END sentence markers."""
        return ['START', *tokens, 'END']

    def tokenize(self, text):
        """Public entry point; identical to :meth:`preprocess`."""
        return self.preprocess(text)

def sentence_to_indices(sentence, vocab):
    """Map each token of ``sentence`` to its id in ``vocab``.

    Args:
        sentence: iterable of tokens.
        vocab: mapping from token to integer id.

    Returns:
        List of ids. Tokens missing from ``vocab`` map to the id stored under
        "<unk>", or to 1 when ``vocab`` has no "<unk>" entry (the same
        fallback used by the later definition of this helper in the notebook).
    """
    # Looked up once, and with .get so a vocabulary without "<unk>" no longer
    # raises KeyError for every token (known or not).
    unk_id = vocab.get("<unk>", 1)
    return [vocab.get(token, unk_id) for token in sentence]





class FrozenLambdaElmoClassifier(nn.Module):
    """Sequence classifier over a fixed-weight mix of the frozen ELMo layers.

    The three ELMo outputs (e0 = embeddings, e1 = ``hf``, e2 = ``hb``) are
    summed with constant scalar weights, fed through an LSTM, and the final
    hidden state of the last LSTM layer is projected to class logits.

    Args:
        elmo_model: module whose ``forward`` returns ``(_, _, (emb, hf, hb))``
            and which exposes ``embedding.embedding_dim`` and
            ``lstm1.hidden_size``. All of its parameters are frozen here.
        num_classes: number of output classes.
        rnn_hidden_size: hidden size of the classifier LSTM.
        rnn_layers: number of classifier LSTM layers.
        fixed_lambdas: optional sequence of exactly three mixing weights for
            (e0, e1, e2); defaults to ``[1/3, 1/3, 1/3]``.

    Raises:
        ValueError: if ``fixed_lambdas`` is given but does not have 3 entries.
    """

    def __init__(self, elmo_model, num_classes, rnn_hidden_size=128, rnn_layers=1, fixed_lambdas=None):
        super(FrozenLambdaElmoClassifier, self).__init__()
        self.elmo = elmo_model
        for param in self.elmo.parameters():
            param.requires_grad = False

        self.embed_dim = self.elmo.embedding.embedding_dim
        self.hidden_dim = self.elmo.lstm1.hidden_size

        # e0 can only be added to e1/e2 if it has the same width; project it
        # when the embedding size differs from the LSTM hidden size.
        if self.embed_dim != self.hidden_dim:
            self.e0_proj = nn.Linear(self.embed_dim, self.hidden_dim)
        else:
            self.e0_proj = nn.Identity()

        if fixed_lambdas is None:
            self.lambdas = [1/3, 1/3, 1/3]
        else:
            # Fail at construction time instead of with an IndexError in
            # forward (too few weights) or by silently ignoring extras.
            if len(fixed_lambdas) != 3:
                raise ValueError(
                    f"fixed_lambdas must contain exactly 3 weights, got {len(fixed_lambdas)}"
                )
            # Plain attribute: the weights are not stored in the state_dict.
            self.lambdas = fixed_lambdas

        self.classifier_rnn = nn.LSTM(input_size=self.hidden_dim,
                                      hidden_size=rnn_hidden_size,
                                      num_layers=rnn_layers,
                                      batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes) for a batch of token ids."""
        # ELMo is frozen, so skip autograd bookkeeping for its forward pass
        # (same as LearnableFunctionElmoClassifier.forward).
        with torch.no_grad():
            _, _, (emb, hf, hb) = self.elmo(input_ids)
        e0 = self.e0_proj(emb)
        e1 = hf
        e2 = hb
        combined = self.lambdas[0] * e0 + self.lambdas[1] * e1 + self.lambdas[2] * e2
        # No masking/packing: padded positions are consumed by the LSTM too.
        rnn_out, (h_n, _) = self.classifier_rnn(combined)
        last_hidden = h_n[-1]
        logits = self.fc(last_hidden)
        return logits


class TrainableLambdaElmoClassifier(nn.Module):
    """Sequence classifier over a learned softmax-weighted mix of the frozen ELMo layers.

    Three scalar logits (``lambda_params``, initialised to zero so the mix
    starts uniform) are softmax-normalised and used to sum e0 (embeddings),
    e1 (``hf``) and e2 (``hb``). The mix is fed through an LSTM and the final
    hidden state of its last layer is projected to class logits.

    Args:
        elmo_model: module whose ``forward`` returns ``(_, _, (emb, hf, hb))``
            and which exposes ``embedding.embedding_dim`` and
            ``lstm1.hidden_size``. All of its parameters are frozen here.
        num_classes: number of output classes.
        rnn_hidden_size: hidden size of the classifier LSTM.
        rnn_layers: number of classifier LSTM layers.
    """

    def __init__(self, elmo_model, num_classes, rnn_hidden_size=128, rnn_layers=1):
        super(TrainableLambdaElmoClassifier, self).__init__()
        self.elmo = elmo_model
        for param in self.elmo.parameters():
            param.requires_grad = False

        self.embed_dim = self.elmo.embedding.embedding_dim
        self.hidden_dim = self.elmo.lstm1.hidden_size

        # e0 can only be added to e1/e2 if it has the same width; project it
        # when the embedding size differs from the LSTM hidden size.
        if self.embed_dim != self.hidden_dim:
            self.e0_proj = nn.Linear(self.embed_dim, self.hidden_dim)
        else:
            self.e0_proj = nn.Identity()

        # Unnormalised mixing weights; softmax is applied in forward.
        self.lambda_params = nn.Parameter(torch.zeros(3))

        self.classifier_rnn = nn.LSTM(input_size=self.hidden_dim,
                                      hidden_size=rnn_hidden_size,
                                      num_layers=rnn_layers,
                                      batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes) for a batch of token ids."""
        # ELMo is frozen, so skip autograd bookkeeping for its forward pass
        # (same as LearnableFunctionElmoClassifier.forward). Gradients still
        # reach lambda_params, e0_proj, the LSTM and fc below.
        with torch.no_grad():
            _, _, (emb, hf, hb) = self.elmo(input_ids)
        e0 = self.e0_proj(emb)
        e1 = hf
        e2 = hb
        lambdas = torch.softmax(self.lambda_params, dim=0)
        combined = lambdas[0] * e0 + lambdas[1] * e1 + lambdas[2] * e2
        # No masking/packing: padded positions are consumed by the LSTM too.
        rnn_out, (h_n, _) = self.classifier_rnn(combined)
        last_hidden = h_n[-1]
        logits = self.fc(last_hidden)
        return logits


class LearnableFunctionElmoClassifier(nn.Module):
    """Sequence classifier that mixes the frozen ELMo layers with a small MLP.

    Per token, e0 (embeddings, projected to ``hidden_dim`` when needed),
    e1 (``hf``) and e2 (``hb``) are concatenated and passed through
    Linear + ReLU. The result is fed through an LSTM and the final hidden
    state of its last layer is projected to class logits.
    """

    def __init__(self, elmo_model, num_classes, rnn_hidden_size=128, rnn_layers=1):
        super().__init__()
        self.elmo = elmo_model
        for frozen_param in self.elmo.parameters():
            frozen_param.requires_grad_(False)

        self.embed_dim = self.elmo.embedding.embedding_dim
        self.hidden_dim = self.elmo.lstm1.hidden_size

        # Bring e0 to the same width as e1/e2 only when the sizes differ.
        widths_match = self.embed_dim == self.hidden_dim
        self.e0_proj = nn.Identity() if widths_match else nn.Linear(self.embed_dim, self.hidden_dim)

        # Learned combination: [e0; e1; e2] -> hidden_dim.
        self.mlp_combiner = nn.Sequential(
            nn.Linear(3 * self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )

        self.classifier_rnn = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes) for a batch of token ids."""
        # The ELMo pass runs without autograd tracking.
        with torch.no_grad():
            _, _, (emb, hf, hb) = self.elmo(input_ids)

        per_token = torch.cat((self.e0_proj(emb), hf, hb), dim=2)
        mixed = self.mlp_combiner(per_token)

        _, (final_states, _) = self.classifier_rnn(mixed)
        # final_states[-1] is the last layer's final hidden state.
        return self.fc(final_states[-1])



class NewsDataset(Dataset):
    """Dataset over a news DataFrame.

    Each item is ``(token_ids, label)``: the 'Description' text tokenised and
    mapped through ``vocab``, and the 'Class Index' value shifted down by one.
    """

    def __init__(self, dataframe, vocab, tokenizer):
        self.data, self.vocab, self.tokenizer = dataframe, vocab, tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        description = self.data.iloc[idx]['Description']
        zero_based_label = self.data.iloc[idx]['Class Index'] - 1

        # The tokenizer yields one token list per sentence; flatten them into
        # a single sequence for the whole description.
        flat_tokens = []
        for sentence_tokens in self.tokenizer.tokenize(description):
            flat_tokens.extend(sentence_tokens)

        return sentence_to_indices(flat_tokens, self.vocab), zero_based_label

def collate_classification(batch):
    """Right-pad variable-length id lists with 0 and batch them with their labels.

    Args:
        batch: sequence of ``(input_ids, label)`` pairs.

    Returns:
        ``(inputs, labels)``: a LongTensor of shape
        (len(batch), longest sequence) and a LongTensor of the labels.
    """
    pad_idx = 0
    sequences = [ids for ids, _ in batch]
    targets = [target for _, target in batch]
    width = max(len(seq) for seq in sequences)

    # Start from an all-padding matrix and overwrite the real tokens.
    padded = torch.full((len(sequences), width), pad_idx, dtype=torch.long)
    for row, seq in enumerate(sequences):
        padded[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)

    return padded, torch.tensor(targets, dtype=torch.long)


def train_classifier(model, dataloader, device, epochs=3, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, printing loss/accuracy per epoch.

    Args:
        model: module mapping a batch of inputs to class logits.
        dataloader: iterable of ``(inputs, labels)`` tensor batches.
        device: device the model and batches are moved to.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` instance, trained in place.
    """
    model.to(device)
    # Parameters with requires_grad=False never receive gradients, so keep
    # them out of the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        total_correct = 0
        total_examples = 0
        model.train()
        total_loss = 0.0
        batches_seen = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        for batches_seen, (inputs, labels) in enumerate(pbar, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Running mean over the batches processed so far. Dividing by
            # len(dataloader) made the displayed loss start near zero and
            # creep upward throughout the epoch.
            pbar.set_postfix(loss=f"{total_loss/batches_seen:.4f}")
            predictions = logits.argmax(dim=1)
            total_correct += (predictions == labels).sum().item()
            total_examples += labels.size(0)
        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
        acc = total_correct / max(total_examples, 1)
        avg_loss = total_loss / max(batches_seen, 1)
        print(f"Epoch {epoch+1} finished. Loss: {avg_loss:.4f}, Training Accuracy: {acc*100:.2f}%")
    return model



def main():
    """Train and save the three ELMo-based news classifiers.

    Loads the pickled vocabulary and the pre-trained BiLM weights, builds the
    "frozen", "trainable" and "learnable" classifiers on top of the same ELMo
    model, trains each on the news training CSV and writes
    ``classifier_<method>.pt`` to the working directory.

    Raises:
        FileNotFoundError: if the vocabulary or BiLM checkpoint is missing.
    """
    EMBED_DIM = 100
    HIDDEN_DIM = 256
    NUM_LAYERS = 2
    BATCH_SIZE = 32
    CLASSIFIER_EPOCHS = 10
    LEARNING_RATE = 1e-3

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocab_path = "/kaggle/input/idk/pytorch/default/1/vocab.pkl"
    bilstm_path = "/kaggle/input/idk/pytorch/default/1/bilstm.pt"
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary file not found at {vocab_path}")
    if not os.path.exists(bilstm_path):
        raise FileNotFoundError(f"Pre-trained ELMo model not found at {bilstm_path}")

    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # it on a vocabulary file you produced yourself.
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)
    elmo_model = ELMoBiLM(vocab_size, EMBED_DIM, HIDDEN_DIM, num_layers=NUM_LAYERS)
    # weights_only=True restricts unpickling to tensors/plain containers and
    # matches how load_static_embedding_model reads its checkpoints.
    elmo_model.load_state_dict(torch.load(bilstm_path, map_location=device, weights_only=True))
    elmo_model.eval()

    csv_path = "/kaggle/input/news-classification/train.csv"
    df = pd.read_csv(csv_path)

    tokenizer = Tokenizer()
    news_dataset = NewsDataset(df, vocab, tokenizer)
    news_dataloader = DataLoader(news_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_classification, num_workers=4, pin_memory=True)

    num_classes = df['Class Index'].nunique()
    # All three classifiers wrap the same (frozen) ELMo instance.
    classifiers = {
        "frozen": FrozenLambdaElmoClassifier(elmo_model, num_classes=num_classes),
        "trainable": TrainableLambdaElmoClassifier(elmo_model, num_classes=num_classes),
        "learnable": LearnableFunctionElmoClassifier(elmo_model, num_classes=num_classes)
    }

    for method_name, classifier in classifiers.items():
        print(f"\nTraining classifier with method: {method_name}")
        classifier = train_classifier(classifier, news_dataloader, device, epochs=CLASSIFIER_EPOCHS, lr=LEARNING_RATE)
        save_path = f"classifier_{method_name}.pt"
        torch.save(classifier.state_dict(), save_path)
        print(f"Saved {method_name} classifier to {save_path}")

In [None]:
def load_static_embedding_model(filepath, device):
    """Load a static word-embedding checkpoint into a frozen ``nn.Embedding``.

    The checkpoint must be a mapping with two entries:
    ``"embeddings"`` (word -> vector) and ``"word_to_index"`` (word -> row).

    Args:
        filepath: path of the checkpoint written with ``torch.save``.
        device: device the returned module is moved to.

    Returns:
        An ``nn.Module`` with ``embedding`` (frozen ``nn.Embedding``) and
        ``vocab`` (the ``word_to_index`` mapping) attributes.

    Raises:
        ValueError: if the checkpoint contains no embedding vectors.
    """
    state = torch.load(filepath, map_location=device, weights_only=True)
    embeddings_dict = state["embeddings"]
    word_to_index = state["word_to_index"]
    if not embeddings_dict:
        # Previously surfaced as a bare StopIteration from next(iter(...)).
        raise ValueError(f"No embedding vectors found in {filepath}")
    vocab_size = len(word_to_index)

    # The vector width is taken from an arbitrary entry.
    sample_vec = next(iter(embeddings_dict.values()))
    embed_dim = len(sample_vec)

    # Rows of words without a vector stay all-zero.
    embedding_matrix = torch.zeros(vocab_size, embed_dim)
    for word, idx in word_to_index.items():
        if word in embeddings_dict:
            # as_tensor accepts lists and existing tensors alike, without the
            # copy-construct warning torch.tensor() emits for tensor input.
            embedding_matrix[idx] = torch.as_tensor(embeddings_dict[word], dtype=torch.float)
    embedding_layer = nn.Embedding(vocab_size, embed_dim)
    embedding_layer.weight.data.copy_(embedding_matrix)
    embedding_layer.weight.requires_grad = False

    class StaticEmbeddingWrapper(nn.Module):
        """Bundles the embedding layer with the vocabulary it was built from."""

        def __init__(self, embedding, vocab):
            super().__init__()
            self.embedding = embedding
            self.vocab = vocab
    return StaticEmbeddingWrapper(embedding_layer, word_to_index).to(device)


def sentence_to_indices(sentence, vocab):
    """Map each token of ``sentence`` to its id in ``vocab``.

    Args:
        sentence: iterable of tokens.
        vocab: mapping from token to integer id.

    Returns:
        List of ids. Tokens missing from ``vocab`` map to the id stored under
        "<unk>", or to 1 when ``vocab`` has no "<unk>" entry.
    """
    # The fallback id does not depend on the token, so look it up once
    # instead of once per token.
    unk_id = vocab.get("<unk>", 1)
    return [vocab.get(token, unk_id) for token in sentence]

class NewsDataset(Dataset):
    """
    News classification dataset backed by a pandas DataFrame.

    The frame must have the columns 'Description' (the text) and
    'Class Index' (the label; 1 is subtracted so labels returned here are
    shifted down by one). Each item is ``(token_ids, label)`` where the text
    has been tokenised with ``tokenizer`` and mapped through ``vocab``.
    """
    def __init__(self, dataframe, vocab, tokenizer):
        self.data = dataframe
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch the row once; every .iloc[idx] call materialises a new Series.
        row = self.data.iloc[idx]
        text = row['Description']
        label = row['Class Index'] - 1
        tokenized_sentences = self.tokenizer.tokenize(text)
        # Flatten the per-sentence token lists into one sequence.
        tokens = [token for sentence in tokenized_sentences for token in sentence]
        indices = sentence_to_indices(tokens, self.vocab)
        return indices, label

def collate_classification(batch):
    """Right-pad variable-length id lists with 0 and batch them with their labels.

    Args:
        batch: sequence of ``(input_ids, label)`` pairs.

    Returns:
        ``(inputs, labels)``: a LongTensor of shape
        (len(batch), longest sequence) and a LongTensor of the labels.
    """
    pad_idx = 0
    sequences = [ids for ids, _ in batch]
    targets = [target for _, target in batch]
    width = max(len(seq) for seq in sequences)

    # Start from an all-padding matrix and overwrite the real tokens.
    padded = torch.full((len(sequences), width), pad_idx, dtype=torch.long)
    for row, seq in enumerate(sequences):
        padded[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)

    return padded, torch.tensor(targets, dtype=torch.long)

class CBOWClassifier(nn.Module):
    """LSTM classifier on top of a frozen static embedding layer.

    Named for use with CBOW embeddings, though nothing in the class is
    CBOW-specific. ``static_model`` must expose ``embedding`` (its parameters
    are frozen here) and ``vocab``.
    """

    def __init__(self, static_model, num_classes, rnn_hidden_size=128, rnn_layers=1):
        super().__init__()
        self.embedding = static_model.embedding
        for frozen in self.embedding.parameters():
            frozen.requires_grad_(False)
        self.vocab = static_model.vocab
        self.rnn = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, input_ids):
        """Return class logits computed from the last LSTM layer's final hidden state."""
        _, (final_states, _) = self.rnn(self.embedding(input_ids))
        return self.fc(final_states[-1])

class SkipgramClassifier(nn.Module):
    """LSTM classifier on top of a frozen static embedding layer.

    Named for use with skip-gram embeddings, though nothing in the class is
    skip-gram-specific. ``static_model`` must expose ``embedding`` (its
    parameters are frozen here) and ``vocab``.
    """

    def __init__(self, static_model, num_classes, rnn_hidden_size=128, rnn_layers=1):
        super().__init__()
        self.embedding = static_model.embedding
        for frozen in self.embedding.parameters():
            frozen.requires_grad_(False)
        self.vocab = static_model.vocab
        self.rnn = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, input_ids):
        """Return class logits computed from the last LSTM layer's final hidden state."""
        embedded = self.embedding(input_ids)
        _, (final_states, _) = self.rnn(embedded)
        return self.fc(final_states[-1])

class SVDClassifier(nn.Module):
    """LSTM classifier on top of a frozen static embedding layer.

    Named for use with SVD embeddings, though nothing in the class is
    SVD-specific. ``static_model`` must expose ``embedding`` (its parameters
    are frozen here) and ``vocab``.
    """

    def __init__(self, static_model, num_classes, rnn_hidden_size=128, rnn_layers=1):
        super().__init__()
        self.embedding = static_model.embedding
        for frozen in self.embedding.parameters():
            frozen.requires_grad_(False)
        self.vocab = static_model.vocab
        self.rnn = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, input_ids):
        """Return class logits computed from the last LSTM layer's final hidden state."""
        rnn_state = self.rnn(self.embedding(input_ids))[1]
        final_states = rnn_state[0]
        return self.fc(final_states[-1])


def train_classifier(model, dataloader, device, epochs=3, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, printing loss/accuracy per epoch.

    Args:
        model: module mapping a batch of inputs to class logits.
        dataloader: iterable of ``(inputs, labels)`` tensor batches.
        device: device the model and batches are moved to.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` instance, trained in place.
    """
    model.to(device)
    model.train()
    # Parameters with requires_grad=False never receive gradients, so keep
    # them out of the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        total_loss = 0.0
        total_correct = 0
        total_examples = 0
        batches_seen = 0
        pbar = tqdm(dataloader, desc=f"Classifier Training Epoch {epoch+1}")
        for batches_seen, (inputs, labels) in enumerate(pbar, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Running mean over the batches processed so far. Dividing by
            # len(dataloader) made the displayed loss start near zero and
            # creep upward throughout the epoch.
            pbar.set_postfix(loss=f"{total_loss/batches_seen:.4f}")
            predictions = logits.argmax(dim=1)
            total_correct += (predictions == labels).sum().item()
            total_examples += labels.size(0)
        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
        acc = total_correct / max(total_examples, 1)
        avg_loss = total_loss / max(batches_seen, 1)
        print(f"Epoch {epoch+1} finished. Loss: {avg_loss:.4f}, Training Accuracy: {acc*100:.2f}%")
    return model


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
def evaluate_model(model, dataloader, device):
    """Run ``model`` over ``dataloader`` in eval mode and collect predictions.

    Args:
        model: module mapping a batch of inputs to class logits.
        dataloader: iterable of ``(inputs, labels)`` tensor batches.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(all_labels, all_preds)``: two flat lists with the true labels and
        the argmax predictions, in dataloader order.
    """
    model.eval()
    gold, predicted = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(dataloader, desc="Evaluating", leave=False):
            batch_logits = model(batch_inputs.to(device))
            # Move predictions back to the CPU before converting to numpy.
            predicted.extend(batch_logits.argmax(dim=1).cpu().numpy())
            gold.extend(batch_labels.numpy())
    return gold, predicted

def load_static_embedding_model(filepath, device):
    """Load a static word-embedding checkpoint into a frozen ``nn.Embedding``.

    The checkpoint must be a mapping with two entries:
    ``"embeddings"`` (word -> vector) and ``"word_to_index"`` (word -> row).

    Args:
        filepath: path of the checkpoint written with ``torch.save``.
        device: device the returned module is moved to.

    Returns:
        An ``nn.Module`` with ``embedding`` (frozen ``nn.Embedding``) and
        ``vocab`` (the ``word_to_index`` mapping) attributes.

    Raises:
        ValueError: if the checkpoint contains no embedding vectors.
    """
    state = torch.load(filepath, map_location=device, weights_only=True)
    embeddings_dict = state["embeddings"]
    word_to_index = state["word_to_index"]
    if not embeddings_dict:
        # Previously surfaced as a bare StopIteration from next(iter(...)).
        raise ValueError(f"No embedding vectors found in {filepath}")
    vocab_size = len(word_to_index)
    # The vector width is taken from an arbitrary entry.
    sample_vec = next(iter(embeddings_dict.values()))
    embed_dim = len(sample_vec)
    # Rows of words without a vector stay all-zero.
    embedding_matrix = torch.zeros(vocab_size, embed_dim)
    for word, idx in word_to_index.items():
        if word in embeddings_dict:
            # as_tensor accepts lists and existing tensors alike, without the
            # copy-construct warning torch.tensor() emits for tensor input.
            embedding_matrix[idx] = torch.as_tensor(embeddings_dict[word], dtype=torch.float)
    embedding_layer = nn.Embedding(vocab_size, embed_dim)
    embedding_layer.weight.data.copy_(embedding_matrix)
    embedding_layer.weight.requires_grad = False
    class StaticEmbeddingWrapper(nn.Module):
        """Bundles the embedding layer with the vocabulary it was built from."""

        def __init__(self, embedding, vocab):
            super().__init__()
            self.embedding = embedding
            self.vocab = vocab
    return StaticEmbeddingWrapper(embedding_layer, word_to_index).to(device)

def _print_split_metrics(header, y_true, y_pred):
    """Print accuracy, macro precision/recall/F1 and the confusion matrix for one split."""
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    cm = confusion_matrix(y_true, y_pred)
    print(header)
    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall   : {rec:.4f}")
    print(f"F1 Score : {f1:.4f}")
    print("Confusion Matrix:")
    print(cm)


def _evaluate_and_report(model, train_loader, test_loader, device):
    """Evaluate ``model`` on the train loader, then the test loader, printing a report for each."""
    true_train, pred_train = evaluate_model(model, train_loader, device)
    _print_split_metrics("Train Metrics:", true_train, pred_train)

    true_test, pred_test = evaluate_model(model, test_loader, device)
    _print_split_metrics("\nTest Metrics:", true_test, pred_test)


def _make_eval_loader(dataframe, vocab, tokenizer):
    """Build a non-shuffled evaluation DataLoader over ``dataframe``."""
    dataset = NewsDataset(dataframe, vocab, tokenizer)
    return DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_classification, num_workers=4)


def main():
    """Evaluate the saved ELMo-based and static-embedding classifiers.

    Loads the news train/test CSVs, restores the three ELMo classifiers and
    the CBOW / skip-gram / SVD classifiers from their checkpoints, and prints
    accuracy, macro precision/recall/F1 and the confusion matrix for each
    model on both splits.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_csv_path = "/kaggle/input/news-classification/train.csv"
    test_csv_path = "/kaggle/input/news-classification/test.csv"
    train_df = pd.read_csv(train_csv_path)
    test_df = pd.read_csv(test_csv_path)

    tokenizer = Tokenizer()

    num_classes = train_df['Class Index'].nunique()

    vocab_path = "/kaggle/input/idk/pytorch/default/1/vocab.pkl"
    bilstm_path = "/kaggle/input/idk/pytorch/default/1/bilstm.pt"
    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # it on a vocabulary file you produced yourself.
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)
    EMBED_DIM = 100
    HIDDEN_DIM = 256
    NUM_LAYERS = 2

    elmo_model = ELMoBiLM(vocab_size, EMBED_DIM, HIDDEN_DIM, num_layers=NUM_LAYERS)
    # weights_only=True restricts unpickling to tensors/plain containers,
    # silences torch.load's FutureWarning and matches
    # load_static_embedding_model.
    elmo_model.load_state_dict(torch.load(bilstm_path, map_location=device, weights_only=True))
    elmo_model.eval()

    # All three classifiers wrap the same (frozen) ELMo instance.
    elmo_classifiers = {
        "frozen": FrozenLambdaElmoClassifier(elmo_model, num_classes=num_classes),
        "trainable": TrainableLambdaElmoClassifier(elmo_model, num_classes=num_classes),
        "learnable": LearnableFunctionElmoClassifier(elmo_model, num_classes=num_classes)
    }
    for name, model in elmo_classifiers.items():
        model_path = f"/kaggle/input/newmodels/pytorch/default/1/classifier2_{name}.pt"
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(device)

    # The ELMo classifiers share one vocabulary, so one pair of loaders suffices.
    loader_train_elmo = _make_eval_loader(train_df, vocab, tokenizer)
    loader_test_elmo = _make_eval_loader(test_df, vocab, tokenizer)

    print("=== ELMo-based Classifiers Evaluation ===")
    for name, model in elmo_classifiers.items():
        print(f"\n--- Model: {name} ---")
        _evaluate_and_report(model, loader_train_elmo, loader_test_elmo, device)

    cbow_path = "/kaggle/input/a3_models/pytorch/default/1/cbow.pt"
    skipgram_path = "/kaggle/input/a3_models/pytorch/default/1/skipgram.pt"
    svd_path = "/kaggle/input/a3_models/pytorch/default/1/svd.pt"
    cbow_static = load_static_embedding_model(cbow_path, device)
    skipgram_static = load_static_embedding_model(skipgram_path, device)
    svd_static = load_static_embedding_model(svd_path, device)

    static_classifiers = {
        "cbow": CBOWClassifier(cbow_static, num_classes=num_classes),
        "skipgram": SkipgramClassifier(skipgram_static, num_classes=num_classes),
        "svd": SVDClassifier(svd_static, num_classes=num_classes)
    }

    for name, model in static_classifiers.items():
        model_path = f"/kaggle/input/newstatic/pytorch/default/1/classifier2_{name}.pt"
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(device)

    print("\n=== Static Embedding Classifiers Evaluation ===")
    for name, model in static_classifiers.items():
        # Each static model carries its own vocabulary, so the data must be
        # re-indexed per model.
        vocab_static = model.vocab
        loader_train_static = _make_eval_loader(train_df, vocab_static, tokenizer)
        loader_test_static = _make_eval_loader(test_df, vocab_static, tokenizer)

        print(f"\n--- Model: {name} ---")
        _evaluate_and_report(model, loader_train_static, loader_test_static, device)

# Entry point: run the evaluation when this code is executed directly.
# The cell output recorded below shows that main() ran from this guard.
if __name__ == "__main__":
    main()

  elmo_model.load_state_dict(torch.load(bilstm_path, map_location=device))
  state_dict = torch.load(model_path, map_location=device)


=== ELMo-based Classifiers Evaluation ===

--- Model: frozen ---


                                                               

Train Metrics:
Accuracy : 0.9239
Precision: 0.9257
Recall   : 0.9239
F1 Score : 0.9241
Confusion Matrix:
[[27344   631   743  1282]
 [  168 29339   117   376]
 [  480   175 26272  3073]
 [  466   229  1390 27915]]


                                                             


Test Metrics:
Accuracy : 0.8450
Precision: 0.8477
Recall   : 0.8450
F1 Score : 0.8455
Confusion Matrix:
[[1600   80   84  136]
 [  38 1760   24   78]
 [  72   34 1490  304]
 [  65   65  198 1572]]

--- Model: trainable ---


                                                               

Train Metrics:
Accuracy : 0.9345
Precision: 0.9346
Recall   : 0.9345
F1 Score : 0.9344
Confusion Matrix:
[[28020   504   775   701]
 [  323 29373   112   192]
 [  517   224 27079  2180]
 [  687   288  1359 27666]]


                                                             


Test Metrics:
Accuracy : 0.8501
Precision: 0.8499
Recall   : 0.8501
F1 Score : 0.8499
Confusion Matrix:
[[1615   83  107   95]
 [  64 1756   40   40]
 [  98   41 1525  236]
 [  96   64  175 1565]]

--- Model: learnable ---


                                                               

Train Metrics:
Accuracy : 0.8921
Precision: 0.8927
Recall   : 0.8921
F1 Score : 0.8919
Confusion Matrix:
[[26677   927  1076  1320]
 [  348 29037   168   447]
 [  924   479 24994  3603]
 [ 1063   607  1987 26343]]


                                                             


Test Metrics:
Accuracy : 0.8582
Precision: 0.8587
Recall   : 0.8582
F1 Score : 0.8578
Confusion Matrix:
[[1639   79   80  102]
 [  38 1798   22   42]
 [  76   46 1501  277]
 [  85   62  169 1584]]


  state_dict = torch.load(model_path, map_location=device)



=== Static Embedding Classifiers Evaluation ===

--- Model: cbow ---


                                                               

Train Metrics:
Accuracy : 0.8662
Precision: 0.8671
Recall   : 0.8662
F1 Score : 0.8650
Confusion Matrix:
[[25357  1918  1656  1069]
 [  358 28972   250   420]
 [  822   663 26410  2105]
 [ 1657  1318  3822 23203]]


                                                             


Test Metrics:
Accuracy : 0.7875
Precision: 0.7882
Recall   : 0.7875
F1 Score : 0.7856
Confusion Matrix:
[[1442  174  163  121]
 [  69 1739   31   61]
 [  99   73 1525  203]
 [ 123  146  352 1279]]

--- Model: skipgram ---


                                                               

Train Metrics:
Accuracy : 0.9166
Precision: 0.9170
Recall   : 0.9166
F1 Score : 0.9168
Confusion Matrix:
[[27593   609   735  1063]
 [  498 28659   373   470]
 [  754   203 27008  2035]
 [  845   207  2212 26736]]


                                                             


Test Metrics:
Accuracy : 0.8333
Precision: 0.8338
Recall   : 0.8333
F1 Score : 0.8335
Confusion Matrix:
[[1598   88  106  108]
 [  80 1704   55   61]
 [ 106   31 1540  223]
 [ 100   53  256 1491]]

--- Model: svd ---


                                                               

Train Metrics:
Accuracy : 0.6615
Precision: 0.6806
Recall   : 0.6615
F1 Score : 0.6616
Confusion Matrix:
[[18618  3170  2075  6137]
 [ 2342 23636   301  3721]
 [ 3527  1211 15257 10005]
 [ 2913  1961  3256 21870]]


                                                             


Test Metrics:
Accuracy : 0.6438
Precision: 0.6588
Recall   : 0.6438
F1 Score : 0.6427
Confusion Matrix:
[[1116  239  153  392]
 [ 169 1494   15  222]
 [ 236   87  940  637]
 [ 176  147  234 1343]]
