In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [None]:
# Paired training data (anchor/text/label columns) and a validation split.
DATA_DIR = '/kaggle/input/suicide-data-paired-for-contrastive-learning'
N_VAL_ROWS = 10_000  # only the first N_VAL_ROWS rows of test.csv are used for validation

data = pd.read_csv(f'{DATA_DIR}/suicide_data.csv')
# nrows stops parsing after N_VAL_ROWS instead of reading the whole file and slicing it.
val_data = pd.read_csv(f'{DATA_DIR}/test.csv', nrows=N_VAL_ROWS)

In [None]:
# Encode the string class labels as integers: 1 = 'suicide', 0 = 'non-suicide'.
# The numeric-dtype guard keeps this cell idempotent: re-running it would otherwise
# push the already-encoded integers through the dict and turn every label into NaN.
if not pd.api.types.is_numeric_dtype(val_data['class']):
    val_data['class'] = val_data['class'].map({'suicide': 1, 'non-suicide': 0})
# Series.map yields NaN for any value missing from the dict; fail loudly instead of silently.
assert val_data['class'].notna().all(), "test.csv contains a class value other than 'suicide'/'non-suicide'"
val_data.head()

In [None]:
# Preview the paired training frame; TextDataset below re-reads the same CSV and uses
# its 'anchor', 'text' and 'label' columns.
data.head()

In [None]:
def contrastive_loss(embedding1, embedding2, label, margin=1.0):
    """Margin-based contrastive loss over a batch of embedding pairs.

    For each pair the Euclidean distance d is computed and

        loss = 0.5 * (y * d**2 + (1 - y) * max(0, margin - d)**2)

    so pairs labelled 1 are pulled together and pairs labelled 0 are pushed
    apart until they are at least ``margin`` away.

    Args:
        embedding1, embedding2: tensors of matching shape, one pair per row
            (e.g. (batch, dim) as produced by the encoder).
        label: per-pair targets, 1 = similar and 0 = dissimilar. Int, float
            and bool tensors are accepted.
        margin: minimum distance enforced for dissimilar pairs.

    Returns:
        Scalar tensor: the mean loss over the batch.
    """
    distance = F.pairwise_distance(embedding1, embedding2)
    # Cast so bool labels work (``1 - bool_tensor`` raises in PyTorch) and the
    # arithmetic happens in the embedding dtype.
    label = label.to(distance.dtype)
    similar_term = label * distance.pow(2)
    dissimilar_term = (1 - label) * F.relu(margin - distance).pow(2)
    return (0.5 * (similar_term + dissimilar_term)).mean()

In [None]:
# Tokenizer and encoder are loaded from the same checkpoint: the tokenizer's ids must
# index the encoder's vocabulary. Evaluation order (tokenizer first) is unchanged.
tokenizer, model = (
    BertTokenizer.from_pretrained('bert-base-uncased'),
    BertModel.from_pretrained('bert-base-uncased'),
)

Architecture

In [None]:
class encoding_latent(nn.Module):
    """Pretrained transformer encoder followed by an MLP projection head.

    The first-token hidden state of ``model`` ([CLS] for BERT) is passed through
    three linear layers (hidden_size -> 512 -> 384 -> 256 by default) to produce
    the latent vector compared by ``contrastive_loss``.

    Args:
        model: encoder exposing ``config.hidden_size`` and returning an object
            with ``last_hidden_state`` when called with ``input_ids`` and
            ``attention_mask`` (e.g. a Hugging Face ``BertModel``).
        hidden_dims: output sizes of fc1, fc2 and fc3 (exactly three values).
        dropout: dropout probability applied after fc1.
        relu_output: apply ReLU to the final latent (default, as trained here).
            This restricts every latent to the non-negative orthant and can
            zero out whole dimensions; pass False for an unconstrained embedding.

    Raises:
        ValueError: if ``hidden_dims`` does not have exactly three entries.
    """

    # Class-level default so modules pickled via torch.save(model) before this
    # flag existed still load and keep the final ReLU.
    relu_output = True

    def __init__(self, model, hidden_dims=(512, 384, 256), dropout=0.3, relu_output=True):
        super().__init__()
        if len(hidden_dims) != 3:
            raise ValueError(f"hidden_dims must have exactly 3 entries, got {len(hidden_dims)}")
        dim1, dim2, dim3 = hidden_dims
        self.pre = model
        # Names fc1/fc2/fc3 are kept so existing state_dicts still match.
        self.fc1 = nn.Linear(self.pre.config.hidden_size, dim1)
        self.fc2 = nn.Linear(dim1, dim2)
        self.fc3 = nn.Linear(dim2, dim3)
        self.dropout = nn.Dropout(dropout)
        self.relu_output = relu_output

    def forward(self, input_ids, attention_mask):
        """Return one latent of size ``hidden_dims[-1]`` per input sequence.

        Assumes input_ids/attention_mask are (batch, seq_len) — TODO confirm for
        callers other than the training loop.
        """
        pre_output = self.pre(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token as the sequence summary.
        embedding = pre_output.last_hidden_state[:, 0, :]

        hidden = F.relu(self.fc1(embedding))
        hidden = self.dropout(hidden)
        hidden = F.relu(self.fc2(hidden))
        latent = self.fc3(hidden)
        return F.relu(latent) if self.relu_output else latent

Training

In [None]:
class TextDataset(Dataset):
    """Map-style dataset of (text, text, label) pairs read from a CSV file.

    Each item tokenizes the two text columns of one row independently, padded
    and truncated to ``max_length``, and returns them together with the row's label.

    Args:
        csv_file: path to a CSV containing both text columns and the label column.
        tokenizer: callable with the Hugging Face tokenizer call signature
            (text, ``padding``, ``truncation``, ``max_length``, ``return_tensors``).
        max_length: pad/truncate length applied to both texts.
        text_column1, text_column2: names of the paired text columns.
        label_column: name of the pair-label column.
    """

    def __init__(self, csv_file, tokenizer, max_length=128,
                 text_column1='anchor', text_column2='text', label_column='label'):
        self.data = pd.read_csv(csv_file)
        self.text_column1 = text_column1
        self.text_column2 = text_column2
        self.label_column = label_column
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text to a fixed-length encoding (the training loop squeezes dim 1 afterwards)."""
        return self.tokenizer(text, padding='max_length', truncation=True,
                              max_length=self.max_length, return_tensors="pt")

    def __getitem__(self, idx):
        # Materialize the row once instead of three separate .iloc lookups.
        row = self.data.iloc[idx]
        inputs1 = self._encode(row[self.text_column1])
        inputs2 = self._encode(row[self.text_column2])
        label = row[self.label_column]
        return inputs1, inputs2, label

In [None]:
# Training pairs, re-read from the same CSV that `data` was loaded from, served in
# shuffled batches of 64.
dataset = TextDataset(
    csv_file='/kaggle/input/suicide-data-paired-for-contrastive-learning/suicide_data.csv',
    tokenizer=tokenizer,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

In [None]:
# Use the GPU when present; fall back to CPU so the notebook still runs (slowly) without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_c = encoding_latent(model).to(device)

In [None]:
# Display the module tree (pretrained encoder `pre` + fc1/fc2/fc3 projection head) as a sanity check.
model_c

In [None]:
n_epochs = 7
# AdamW over every parameter of model_c, including the pretrained encoder's weights
# (full fine-tuning, not just the projection head).
opt = torch.optim.AdamW(model_c.parameters(), lr=1e-4)

In [None]:

# Move batches to wherever the model's weights live (GPU or CPU) instead of hard-coding 'cuda'.
device = next(model_c.parameters()).device


def _to_device(encoding, device):
    """Drop the tokenizer's per-example dim of 1 and move input_ids/attention_mask to `device`."""
    input_ids = encoding['input_ids'].squeeze(1).to(device)
    attention_mask = encoding['attention_mask'].squeeze(1).to(device)
    return input_ids, attention_mask


for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")
    epoch_loss = 0.0
    model_c.train()

    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{n_epochs}"):
        inputs1, inputs2, label = batch
        label = label.to(device)
        input_id1, attention_mask1 = _to_device(inputs1, device)
        input_id2, attention_mask2 = _to_device(inputs2, device)

        opt.zero_grad()
        # Both sides of each pair go through the same (shared-weight) encoder.
        embedding1 = model_c(input_id1, attention_mask1)
        embedding2 = model_c(input_id2, attention_mask2)
        loss = contrastive_loss(embedding1, embedding2, label)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
    # Mean of per-batch losses (the last, possibly smaller, batch counts equally).
    average_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch + 1} --> Loss: {average_loss:.6f}")
            

In [None]:
# Full pickled module (kept for existing consumers). Loading it needs the encoding_latent
# class importable and, on PyTorch >= 2.6, torch.load(..., weights_only=False).
torch.save(model_c, 'vector_pooler_128.pth')
# Portable weights-only checkpoint: rebuild encoding_latent(BertModel...) and call load_state_dict().
torch.save(model_c.state_dict(), 'vector_pooler_128_state_dict.pth')