In [1]:
import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from pytorch_metric_learning.losses import NTXentLoss
import pandas as pd
import numpy as np

In [2]:
with open('cls_emb.pkl', 'rb') as f:
    cls = pickle.load(f)
with open('feature_vectors.pkl', 'rb')as f:
    feature_vectors= pickle.load(f)

In [3]:
cls[0].size()

torch.Size([1, 768])

In [4]:
response_df = pd.read_csv('final_data.csv')
map_dict = {'llama3.1-70b':0, 'mistral':1, 'gpt-4o-2024-05-13':2}
response_df['model_nums'] = response_df['model'].map(map_dict)

In [5]:
embeddings = [torch.cat((cls[i].float(), torch.from_numpy(feature_vectors[i]).unsqueeze(0).float()), dim=1) for i in range(len(cls))]

In [6]:
def extract_and_split(response_df, embeddings, temperature):
    temp_idx = response_df[response_df['temperature'] == temperature].index
    temp_embs = [embeddings[idx] for idx in temp_idx]
    temp_targs = [response_df['model_nums'][idx] for idx in temp_idx]
    
    return train_test_split(temp_embs, temp_targs, test_size=0.1, random_state=42)
    
temp_0_train, temp_0_test, temp_0_targs_train, temp_0_targs_test = extract_and_split(response_df, embeddings, 0)
temp_7_train, temp_7_test, temp_7_targs_train, temp_7_targs_test = extract_and_split(response_df, embeddings, 0.7)
temp_14_train, temp_14_test, temp_14_targs_train, temp_14_targs_test = extract_and_split(response_df, embeddings, 1.4)
temp_all_train, temp_all_test, temp_all_targs_train, temp_all_targs_test = train_test_split(embeddings, response_df['model_nums'], 
                                                                                            test_size=0.1, random_state=42)

In [7]:
class FAM(nn.Module):
    def __init__(self, embed_size, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.fc = nn.Linear(embed_size, hidden_size)
        
    def init_weights(self):
        initrange = 0.2
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()


    def forward(self, text):
        batch,  dim = text.size()
        feat = self.fc(torch.tanh(self.dropout(text.view(batch, dim))))
        feat = F.normalize(feat, dim=1)
        return feat

In [8]:
class Projection(nn.Module):
    def __init__(self, hidden_size, projection_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, projection_size)
        self.ln = nn.LayerNorm(projection_size)
        self.bn = nn.BatchNorm1d(projection_size)
        self.init_weights()
    def init_weights(self):
        initrange = 0.01
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()


    def forward(self, text):
        batch,  dim = text.size()
        return self.ln(self.fc(torch.tanh(text.view(batch, dim))))

In [30]:
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        """
        Implementation of the loss described in the paper Supervised Contrastive Learning :
        https://arxiv.org/abs/2004.11362

        :param temperature: int
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections, targets):
        """

        :param projections: torch.Tensor, shape [batch_size, projection_dim]
        :param targets: torch.Tensor, shape [batch_size]
        :return: torch.Tensor, scalar
        """
        device = torch.device("cuda") if projections.is_cuda else torch.device("cpu")

        dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
        # Minus max for numerical stability with exponential. Same done in cross entropy. Epsilon added to avoid log(0)
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(device)
        mask_anchor_out = (1 - torch.eye(exp_dot_tempered.shape[0])).to(device)
        mask_combined = mask_similar_class * mask_anchor_out
        cardinality_per_samples = torch.sum(mask_combined, dim=1)

        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        return supervised_contrastive_loss

In [10]:
class Classifier(nn.Module):
    def __init__(self, hidden_size, num_class, hidden_dropout_prob):
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.fc = nn.Linear(hidden_size, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.02
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, feature):
        return self.fc(torch.tanh(feature))

In [11]:
class WordEmbeddingDataset(Dataset):
    def __init__(self, cls_embs, targs):
        self.cls_embs = cls_embs
        self.targs = targs 

    def __len__(self):
        return len(self.cls_embs)

    def __getitem__(self, idx):
        return self.cls_embs[idx], self.targs[idx]

In [12]:
BATCH_SIZE = 100
dataset_0 = WordEmbeddingDataset(temp_0_train, temp_0_targs_train)
dataset_0_test = WordEmbeddingDataset(temp_0_test, temp_0_targs_test)

dataset_7 =  WordEmbeddingDataset(temp_7_train, temp_7_targs_train)
dataset_7_test = WordEmbeddingDataset(temp_7_test, temp_7_targs_test)

dataset_14 = WordEmbeddingDataset(temp_14_train, temp_14_targs_train)
dataset_14_test = WordEmbeddingDataset(temp_14_test, temp_14_targs_test)

dataset_all = WordEmbeddingDataset(temp_all_train, temp_all_targs_train)
dataset_all_test = WordEmbeddingDataset(temp_all_test, temp_all_targs_test)

data_loader_0 = DataLoader(dataset_0, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_loader_0_test = DataLoader(dataset_0_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

data_loader_7 = DataLoader(dataset_7, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_loader_7_test = DataLoader(dataset_7_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

data_loader_14 = DataLoader(dataset_14, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_loader_14_test = DataLoader(dataset_14_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

data_loader_all = DataLoader(dataset_all, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_loader_all_test = DataLoader(dataset_all_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [32]:
def train(fa_module, proj_module, supconloss_module, classifier, data_loader, optimizer, classifier_loss_fn):
    fa_module.train()
    proj_module.train()
    supconloss_module.train()
    classifier.train()

    batch_acc_cumulative = 0
    n_batches = 0
    train_loss = 0

    for _, data in tqdm(enumerate(data_loader)):
        n_batches += 1
        optimizer.zero_grad()

        
        cls_embs = data[0].squeeze(1)  # Assuming BERT CLS embeddings
        targets = data[1]
        
        fam_output = fa_module(cls_embs)   
        proj_output = proj_module(fam_output)
        supcon_loss = supconloss_module(proj_output, targets)
        classifier_output = classifier(fam_output)  
        classifier_loss = classifier_loss_fn(classifier_output, targets)

        loss = supcon_loss + classifier_loss 

        loss.backward()   # Backpropagate the combined loss
        optimizer.step()  # Update the model parameters

        
        train_loss += loss.item()

        
        batch_predictions = classifier_output.argmax(1)
        batch_acc = (batch_predictions == targets).sum().item() / len(targets)
        batch_acc_cumulative += batch_acc

    
    average_acc = batch_acc_cumulative / n_batches
    average_loss = train_loss / n_batches

    print(f'Average Accuracy: {average_acc * 100:.2f}%')

    return average_loss, average_acc

In [45]:
def evaluate(fa_module, classifier, data_loader):
    fa_module.eval()  
    classifier.eval()
    
    correct = 0
    total = 0
    with torch.no_grad():  
        for data in data_loader:
            cls_embs = data[0].squeeze(1)  
            targets = data[1].tolist()
            
            # Forward pass through FAM
            fam_output = fa_module(cls_embs)
            
            # Forward pass through Classifier
            final_output = classifier(fam_output)
            
            # Get predictions
            preds = final_output.argmax(1).tolist()
            
            total += len(preds) 
            correct += np.sum(np.array(preds) == np.array(targets))  # Count how many predictions are correct

    accuracy = correct / total if total > 0 else 0  # Avoid division by zero
    print(f'Test Accuracy: {accuracy:.4f}')
    return accuracy

In [46]:
fam_0 = FAM(797, 256, 0.3)
proj_0 = Projection(256, 128)
supcon_0 = SupConLoss()
classifier_0 = Classifier(256, 3, 0.3)

optimizer = torch.optim.Adam(list(fam_0.parameters()) + 
                             list(proj_0.parameters()) + 
                             list(classifier_0.parameters()), lr=0.001)
classifier_loss = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Training Loop
for epoch in range(1, 50):
    loss, acc = train(fam_0, proj_0, supcon_0, classifier_0, data_loader_0, optimizer, classifier_loss)  
    print(f'Epoch {epoch}, Loss: {loss:.4f}')
    scheduler.step()

# Evaluation
test_accuracy = evaluate(fam_0, classifier_0, data_loader_0_test)

43it [00:00, 146.51it/s]


Average Accuracy: 36.67%
Epoch 1, Loss: 5.6896


43it [00:00, 151.63it/s]


Average Accuracy: 39.58%
Epoch 2, Loss: 5.6754


43it [00:00, 158.62it/s]


Average Accuracy: 42.16%
Epoch 3, Loss: 5.6580


43it [00:00, 166.28it/s]


Average Accuracy: 45.16%
Epoch 4, Loss: 5.6342


43it [00:00, 154.07it/s]


Average Accuracy: 46.51%
Epoch 5, Loss: 5.6217


43it [00:00, 144.84it/s]


Average Accuracy: 47.23%
Epoch 6, Loss: 5.6184


43it [00:00, 168.32it/s]


Average Accuracy: 49.07%
Epoch 7, Loss: 5.6079


43it [00:00, 165.09it/s]


Average Accuracy: 48.93%
Epoch 8, Loss: 5.6061


43it [00:00, 168.26it/s]


Average Accuracy: 47.51%
Epoch 9, Loss: 5.6188


43it [00:00, 144.77it/s]


Average Accuracy: 48.56%
Epoch 10, Loss: 5.6103


43it [00:00, 149.00it/s]


Average Accuracy: 49.95%
Epoch 11, Loss: 5.5932


43it [00:00, 148.81it/s]


Average Accuracy: 47.44%
Epoch 12, Loss: 5.6063


43it [00:00, 162.52it/s]


Average Accuracy: 48.49%
Epoch 13, Loss: 5.6014


43it [00:00, 150.11it/s]


Average Accuracy: 47.84%
Epoch 14, Loss: 5.6136


43it [00:00, 134.67it/s]


Average Accuracy: 48.19%
Epoch 15, Loss: 5.5977


43it [00:00, 163.38it/s]


Average Accuracy: 49.91%
Epoch 16, Loss: 5.5880


43it [00:00, 162.94it/s]


Average Accuracy: 49.81%
Epoch 17, Loss: 5.5907


43it [00:00, 141.27it/s]


Average Accuracy: 50.98%
Epoch 18, Loss: 5.5858


43it [00:00, 161.80it/s]


Average Accuracy: 51.23%
Epoch 19, Loss: 5.5781


43it [00:00, 168.32it/s]


Average Accuracy: 49.81%
Epoch 20, Loss: 5.5897


43it [00:00, 162.04it/s]


Average Accuracy: 51.63%
Epoch 21, Loss: 5.5746


43it [00:00, 133.68it/s]


Average Accuracy: 52.09%
Epoch 22, Loss: 5.5753


43it [00:00, 138.30it/s]


Average Accuracy: 51.47%
Epoch 23, Loss: 5.5722


43it [00:00, 131.34it/s]


Average Accuracy: 50.60%
Epoch 24, Loss: 5.5823


43it [00:00, 142.29it/s]


Average Accuracy: 51.23%
Epoch 25, Loss: 5.5764


43it [00:00, 146.26it/s]


Average Accuracy: 51.07%
Epoch 26, Loss: 5.5756


43it [00:00, 163.79it/s]


Average Accuracy: 50.60%
Epoch 27, Loss: 5.5739


43it [00:00, 162.61it/s]


Average Accuracy: 52.35%
Epoch 28, Loss: 5.5619


43it [00:00, 143.77it/s]


Average Accuracy: 51.60%
Epoch 29, Loss: 5.5726


43it [00:00, 138.65it/s]


Average Accuracy: 51.77%
Epoch 30, Loss: 5.5690


43it [00:00, 158.02it/s]


Average Accuracy: 52.02%
Epoch 31, Loss: 5.5748


43it [00:00, 156.83it/s]


Average Accuracy: 51.28%
Epoch 32, Loss: 5.5705


43it [00:00, 141.22it/s]


Average Accuracy: 51.86%
Epoch 33, Loss: 5.5675


43it [00:00, 169.26it/s]


Average Accuracy: 51.51%
Epoch 34, Loss: 5.5693


43it [00:00, 169.65it/s]


Average Accuracy: 50.40%
Epoch 35, Loss: 5.5803


43it [00:00, 154.50it/s]


Average Accuracy: 51.84%
Epoch 36, Loss: 5.5712


43it [00:00, 139.21it/s]


Average Accuracy: 52.05%
Epoch 37, Loss: 5.5651


43it [00:00, 143.75it/s]


Average Accuracy: 51.70%
Epoch 38, Loss: 5.5655


43it [00:00, 160.23it/s]


Average Accuracy: 51.81%
Epoch 39, Loss: 5.5683


43it [00:00, 154.80it/s]


Average Accuracy: 51.95%
Epoch 40, Loss: 5.5682


43it [00:00, 165.63it/s]


Average Accuracy: 51.84%
Epoch 41, Loss: 5.5714


43it [00:00, 140.23it/s]


Average Accuracy: 50.37%
Epoch 42, Loss: 5.5744


43it [00:00, 142.04it/s]


Average Accuracy: 51.28%
Epoch 43, Loss: 5.5760


43it [00:00, 110.99it/s]


Average Accuracy: 51.88%
Epoch 44, Loss: 5.5639


43it [00:00, 106.14it/s]


Average Accuracy: 52.35%
Epoch 45, Loss: 5.5616


43it [00:00, 101.58it/s]


Average Accuracy: 52.51%
Epoch 46, Loss: 5.5618


43it [00:00, 104.21it/s]


Average Accuracy: 52.40%
Epoch 47, Loss: 5.5641


43it [00:00, 84.11it/s]


Average Accuracy: 52.56%
Epoch 48, Loss: 5.5595


43it [00:00, 99.24it/s] 

Average Accuracy: 52.05%
Epoch 49, Loss: 5.5647
Test Accuracy: 0.5825





In [None]:
# Model Instantiation
fam_0 = FAM(797, 256, 0.3)
proj_0 = Projection(256, 128)
supcon_0 = SupConLoss()
classifier_0 = Classifier(256, 3, 0.3)

# Optimizer only for models with parameters
optimizer = torch.optim.Adam(list(fam_0.parameters()) + 
                             list(proj_0.parameters()) + 
                             list(classifier_0.parameters()), lr=0.001)

# Correct instantiation of CrossEntropyLoss
classifier_loss = nn.CrossEntropyLoss()

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Training Loop
for epoch in range(1, 50):
    loss, acc = train(fam_0, proj_0, supcon_0, classifier_0, data_loader_0, optimizer, classifier_loss)  
    print(f'Epoch {epoch}, Loss: {loss:.4f}')
    scheduler.step()

# Evaluation
test_accuracy = evaluate(fam_0, classifier_0, data_loader_0_test)

In [None]:
fam_14 = FAM(797, 256, 0.3)
proj_14 = Projection(256, 128)
supcon_14 = SupConHead()
classifier_14 = Classifier(256, 3, 0.3)
optimizer = torch.optim.Adam(list(fam_14.parameters()) + 
                             list(proj_14.parameters()) + 
                             list(supcon_14.parameters()) + 
                             list(classifier_14.parameters()), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
for epoch in range(1, 20):
    loss, acc = train(fam_14, proj_14, supcon_14, classifier_14, data_loader_14)  
    print(f'Epoch {epoch}, Loss: {loss:.4f}')
    scheduler.step()
test_accuracy = evaluate(fam_14, proj_14, supcon_14, classifier_14, data_loader_14_test)

In [None]:
fam_all = FAM(768, 256, 0.3)
proj_all = Projection(256, 128)
supcon_all = SupConHead()
classifier_all = Classifier(256, 3, 0.3)
optimizer = torch.optim.Adam(list(fam_all.parameters()) + 
                             list(proj_all.parameters()) + 
                             list(supcon_all.parameters()) + 
                             list(classifier_all.parameters()), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
for epoch in range(1, 20):
    loss, acc = train(fam_all, proj_all, supcon_all, classifier_all, data_loader_all)  
    print(f'Epoch {epoch}, Loss: {loss:.4f}')
    scheduler.step()
test_accuracy = evaluate(fam_all, proj_all, supcon_all, classifier_all, data_loader_all_test)