In [1]:
from multiprocessing.spawn import prepare
import os
import json
import torch

from datasets import load_dataset, Value
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from functools import partial


def read_jsonl(file_path):
    """
    Reads a JSONL (JSON Lines) file and returns a list of dictionaries.

    Parameters:
    file_path (str): The path to the JSONL file.

    Returns:
    list: A list of dictionaries where each dictionary represents a JSON object from the file.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line.strip()))
    return data

class utt_dataset(Dataset):
    def __init__(self, data, encoder_tokenizer, decoder_tokenizer, max_length=64, noiser=None, cse=False):
        self.data = data
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.max_length = max_length
        self.noiser = noiser
        self.cse = cse
        
    def __getitem__(self, index):
        # Get the text data at the specified index
        if 'src' in self.data[index].keys():
            text = self.data[index]['src']
        elif 'text' in self.data[index].keys():
            text = self.data[index]['text']
        
        # Tokenize the input text using the encoder tokenizer
        encoder_input = self.encoder_tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Tokenize the input text using the decoder tokenizer
        decoder_input = self.decoder_tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Apply noise to the encoder input IDs if a noiser is provided
        if self.noiser:
            noisy_encoder_input_ids = self.noiser(encoder_input['input_ids'].squeeze())
        else:
            noisy_encoder_input_ids = encoder_input['input_ids'].squeeze()
        
        # Prepare the dictionary for returning the data
        x = {
            'text':text,
            'encoder_input_ids': noisy_encoder_input_ids,
            'decoder_input_ids': decoder_input['input_ids'].squeeze(),
            'attention_mask': encoder_input['attention_mask'].squeeze(),
            'decoder_attention_mask': decoder_input['attention_mask'].squeeze()
        }
        
        if self.cse: #more stuff for contrastive learning
            x['has_similar'] = 'similar' in self.data[index].keys()
            x['has_contrastive'] = 'contrastive' in self.data[index].keys()
            
            item = self.data[index]
            # Add similar sentences if they exist
            if 'similar' in item.keys():
                similar_input = self.encoder_tokenizer.encode_plus(
                    item['similar'],
                    padding='max_length',
                    max_length=self.max_length,
                    truncation=True,
                    return_tensors='pt'
                )
                x['similar_ids'] = similar_input['input_ids']
                x['similar_attention_mask'] = similar_input['attention_mask']

            # Add contrastive sentences if they exist
            if 'contrastive' in item.keys():
                contrastive_input = self.encoder_tokenizer.encode_plus(
                    item['contrastive'],
                    padding='max_length',
                    max_length=self.max_length,
                    truncation=True,
                    return_tensors='pt'
                )
                x['contrastive_ids'] = contrastive_input['input_ids']
                x['contrastive_attention_mask'] = contrastive_input['attention_mask']
        return x

    def __len__(self):
        return len(self.data)

def collate_fn(batch, encoder_tokenizer=None, decoder_tokenizer=None, cse=False):
    # Extract sequences from the batch
    encoder_input_ids = [item['encoder_input_ids'] for item in batch]
    decoder_input_ids = [item['decoder_input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    decoder_attention_masks = [item['decoder_attention_mask'] for item in batch]
    
    # Pad sequences to the maximum length within the batch
    encoder_input_ids = torch.nn.utils.rnn.pad_sequence(encoder_input_ids, batch_first=True, padding_value=0)
    decoder_input_ids = torch.nn.utils.rnn.pad_sequence(decoder_input_ids, batch_first=True, padding_value=0)
    attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    decoder_attention_masks = torch.nn.utils.rnn.pad_sequence(decoder_attention_masks, batch_first=True, padding_value=0)
    
    # Decode to text (ignoring the pad tokens)
    decoded_encoder_texts = [
        encoder_tokenizer.decode(ids, skip_special_tokens=True) 
        for ids in encoder_input_ids
    ]
    
    decoded_decoder_texts = [
        decoder_tokenizer.decode(ids, skip_special_tokens=True)
        for ids in decoder_input_ids
    ]
    
    labels = decoder_input_ids.clone()
    
    # Set padding tokens in labels to -100 to ignore them in the loss computation
    labels[labels == 0] = -100
    
    batch_dict = {
        'encoder_input_ids': encoder_input_ids,
        'decoder_input_ids': decoder_input_ids,
        'attention_mask': attention_masks,
        'decoder_attention_mask': decoder_attention_masks,
        'decoded_encoder_texts': decoded_encoder_texts,
        'decoded_decoder_texts': decoded_decoder_texts,
        'label_ids': labels
    }
    
    if cse:
        has_similar = torch.tensor([item['has_similar'] for item in batch])
        has_contrastive = torch.tensor([item['has_contrastive'] for item in batch])
                # Handle similar sentences if they exist
        if any(has_similar):
            similar_ids = [item['similar_ids'] for item in batch if 'similar_ids' in item]
            similar_attention_mask = [item['similar_attention_mask'] for item in batch if 'similar_attention_mask' in item]
            if similar_ids:
                batch_dict['similar_ids'] = torch.nn.utils.rnn.pad_sequence(similar_ids, batch_first=True, padding_value=0).squeeze(1)
                batch_dict['similar_attention_mask'] = torch.nn.utils.rnn.pad_sequence(similar_attention_mask, batch_first=True, padding_value=0).squeeze(1)

        # Handle contrastive sentences if they exist
        if any(has_contrastive):
            contrastive_ids = [item['contrastive_ids'] for item in batch if 'contrastive_ids' in item]
            contrastive_attention_mask = [item['contrastive_attention_mask'] for item in batch if 'contrastive_attention_mask' in item]
            if contrastive_ids:
                batch_dict['contrastive_ids'] = torch.nn.utils.rnn.pad_sequence(contrastive_ids, batch_first=True, padding_value=0).squeeze(1)
                batch_dict['contrastive_attention_mask'] = torch.nn.utils.rnn.pad_sequence(contrastive_attention_mask, batch_first=True, padding_value=0).squeeze(1)

    return batch_dict

def get_dataloader(data, encoder_tokenizer, decoder_tokenizer, batch_size=32, max_length=64, noiser=None, shuffle=True, cse=False):
    """
    Creates a DataLoader for the given dataset.

    Args:
    - data (list): List of datapoints, each being a dictionary with a 'text' field.
    - encoder_tokenizer (transformers.PreTrainedTokenizer): Tokenizer for the encoder.
    - decoder_tokenizer (transformers.PreTrainedTokenizer): Tokenizer for the decoder.
    - batch_size (int): Number of samples per batch to load.
    - max_length (int): Maximum sequence length for the tokenized inputs.
    - noiser (callable, optional): Function to add noise to the encoder input IDs.
    - shuffle (bool): Whether to shuffle the data at every epoch.

    Returns:
    - DataLoader: DataLoader for the dataset.
    """
    
    # Initialize the dataset
    dataset = utt_dataset(
        data=data,
        encoder_tokenizer=encoder_tokenizer,
        decoder_tokenizer=decoder_tokenizer,
        max_length=max_length,
        noiser=noiser,
        cse=cse
    )

    # Create the DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=partial(collate_fn, encoder_tokenizer=encoder_tokenizer, decoder_tokenizer=decoder_tokenizer, cse=cse)
    )

    return dataloader

def get_utt_dataloader(file_path, encoder_tokenizer, decoder_tokenizer, max_length=64, noiser=None, batch_size=32, dev_mode=False, cse=False):
    
    if '.jsonl' in file_path:
        data = read_jsonl(file_path)
    else:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    
    if dev_mode:
        data=data[:100]
    
    dataloader = get_dataloader(data, encoder_tokenizer, decoder_tokenizer, 
                                max_length=max_length, noiser=noiser, batch_size=batch_size, cse=cse)
    return dataloader

  _torch_pytree._register_pytree_node(


In [2]:
from transformers import T5ForConditionalGeneration, AutoTokenizer

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [3]:
data = [
    {
        'input': 'How are you?',
        'output': 'I am fine.',
        'similar': 'How do you do?'
    },
    {
        'input': 'What is the weather?',
        'output': 'It is sunny.',
        'contrastive': 'I like pizza.'
    },
    {
        'input': 'Tell me a joke',
        'output': 'Why did the chicken cross the road?'
    }
]

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')



In [4]:
# Create dataloader
dataloader = get_utt_dataloader('../datasets/nli_for_simcse_utt/dev.jsonl', tokenizer, tokenizer, dev_mode=True, cse=True)

In [5]:
for d in dataloader:
    print(d)

{'encoder_input_ids': tensor([[   37,   388,    19,  ...,     0,     0,     0],
        [ 8548,  7494, 17926,  ...,     0,     0,     0],
        [   71,  1021,  4940,  ...,     0,     0,     0],
        ...,
        [  438,  2937, 15618,  ...,     0,     0,     0],
        [   71,  1021,  3202,  ...,     0,     0,     0],
        [ 5245, 27424,     7,  ...,     0,     0,     0]]), 'decoder_input_ids': tensor([[   37,   388,    19,  ...,     0,     0,     0],
        [ 8548,  7494, 17926,  ...,     0,     0,     0],
        [   71,  1021,  4940,  ...,     0,     0,     0],
        ...,
        [  438,  2937, 15618,  ...,     0,     0,     0],
        [   71,  1021,  3202,  ...,     0,     0,     0],
        [ 5245, 27424,     7,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 

In [36]:
model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')



In [121]:
org = model.get_encoder()(input_ids=d['encoder_input_ids'], attention_mask=d['attention_mask'])
sim = model.get_encoder()(input_ids=d['similar_ids'], attention_mask=d['similar_attention_mask'])
con = model.get_encoder()(input_ids=d['contrastive_ids'], attention_mask=d['contrastive_attention_mask'])

In [122]:
import torch.nn as nn
import torch.nn.functional as F
def get_simcse_loss(org, sim, con, pooler='avg', temp=0.05, hard_negative_weight=1):
    if sim is None and con is None:
        return 0
    
    cos_sim = nn.CosineSimilarity(dim=-1)
    batch_size = org.shape[0]
    
    # Pooling and normalization
    if pooler == 'avg':
        org_pooled = F.normalize(org.mean(dim=1), p=2, dim=-1)
        if sim is not None:
            sim_pooled = F.normalize(sim.mean(dim=1), p=2, dim=-1)
        if con is not None:
            con_pooled = F.normalize(con.mean(dim=1), p=2, dim=-1)
    else:
        raise NotImplementedError
    
    if sim_pooled is not None:
        # For supervised SimCSE, compare org with sim
        pos_cos = cos_sim(org_pooled.unsqueeze(1), sim_pooled.unsqueeze(0)) / temp
        
        if con_pooled is not None:
            # Add hard negatives if provided
            con_cos = cos_sim(org_pooled.unsqueeze(1), con_pooled.unsqueeze(0)) / temp
            base = torch.cat([pos_cos, con_cos], 1)
        else:
            base = pos_cos
            
        labels = torch.arange(batch_size).long().to(org.device)
        loss_fct = nn.CrossEntropyLoss()
        
        if con_pooled is not None:
            # Apply weights for hard negatives
            weights = torch.tensor(
                [[0.0] * (base.size(-1) - con_cos.size(-1)) + [0.0] * i + [hard_negative_weight] + [0.0] * (con_cos.size(-1) - i - 1) 
                 for i in range(con_cos.size(-1))]).to(org.device)
            base += weights
        
        loss = loss_fct(base, labels)

        return loss
    
    return 0

In [123]:
get_simcse_loss(org.last_hidden_state, org.last_hidden_state, con.last_hidden_state)

Org pooled norm: tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)
Sim pooled norm: tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)
Con pooled norm: tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)
Positive cosine similarities:
tensor([[1.0000, 0.5376, 0.5670, 0.7013],
        [0.5376, 1.0000, 0.6180, 0.7106],
        [0.5670, 0.6180, 1.0000, 0.5380],
        [0.7013, 0.7106, 0.5380, 1.0000]], grad_fn=<MulBackward0>)
Negative cosine similarities:
tensor([[0.6207, 0.5539, 0.6196, 0.6437],
        [0.4293, 0.6174, 0.5574, 0.4936],
        [0.5422, 0.5650, 0.6543, 0.5008],
        [0.5024, 0.6307, 0.5809, 0.6021]], grad_fn=<MulBackward0>)


tensor(0.0055, grad_fn=<NllLossBackward0>)

In [60]:
d['encoder_input_ids'].shape

torch.Size([4, 64])

In [84]:
org.last_hidden_state.mean(dim=1).unsqueeze(1).shape

torch.Size([4, 1, 768])

In [86]:
con.last_hidden_state.mean(dim=1).unsqueeze(0).shape

torch.Size([1, 4, 768])

In [87]:
cosine=nn.CosineSimilarity(dim=-1)
cosine(org.last_hidden_state.mean(dim=1).unsqueeze(1), con.last_hidden_state.mean(dim=1).unsqueeze(0))

tensor([[0.6207, 0.5539, 0.6196, 0.6437],
        [0.4293, 0.6174, 0.5574, 0.4936],
        [0.5422, 0.5650, 0.6543, 0.5008],
        [0.5024, 0.6307, 0.5809, 0.6021]], grad_fn=<SumBackward1>)

In [118]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Create sample embeddings
batch_size = 3
seq_length = 4
hidden_dim = 8

# Create original embeddings
org = torch.randn(batch_size, seq_length, hidden_dim)
# Create similar embeddings (slightly perturbed version of org)
sim = org + 0.1 * torch.randn_like(org)
# Create contrasting embeddings (completely different)
con = torch.randn(batch_size, seq_length, hidden_dim)

def get_simcse_loss(org, sim, con, pooler='avg', temp=0.05, hard_negative_weight=1):
    if sim is None and con is None:
        return 0
    
    cos_sim = nn.CosineSimilarity(dim=-1)
    batch_size = org.shape[0]
    
    # Pooling and normalization
    if pooler == 'avg':
        org_pooled = F.normalize(org.mean(dim=1), p=2, dim=-1)
        if sim is not None:
            sim_pooled = F.normalize(sim.mean(dim=1), p=2, dim=-1)
        if con is not None:
            con_pooled = F.normalize(con.mean(dim=1), p=2, dim=-1)
    else:
        raise NotImplementedError
    
    print("\nDebug - Normalized vector norms:")
    print("org_pooled norms:", torch.norm(org_pooled, dim=1))  # Should all be 1
    print("sim_pooled norms:", torch.norm(sim_pooled, dim=1))  # Should all be 1
    print("con_pooled norms:", torch.norm(con_pooled, dim=1))  # Should all be 1
    
    if sim_pooled is not None:
        # Calculate similarity matrices
        pos_cos = cos_sim(org_pooled.unsqueeze(1), sim_pooled.unsqueeze(0)) 
        con_cos = cos_sim(org_pooled.unsqueeze(1), con_pooled.unsqueeze(0))
        
        print("\nDebug - Similarity matrices (before temperature scaling):")
        print("Positive similarities (org vs sim):")
        print(pos_cos * temp)  # Multiply by temp to see raw similarities
        print("\nNegative similarities (org vs con):")
        print(con_cos * temp)  # Multiply by temp to see raw similarities
        
        # Combine positive and negative similarities
        base = torch.cat([pos_cos, con_cos], 1)
        
        # Create labels (diagonal should have highest values)
        labels = torch.arange(batch_size).long()
        
        # Apply weights for hard negatives
        weights = torch.tensor(
            [[0.0] * (base.size(-1) - con_cos.size(-1)) + [0.0] * i + [hard_negative_weight] + [0.0] * (con_cos.size(-1) - i - 1) 
             for i in range(con_cos.size(-1))])
        
        base += weights
        
        print("\nDebug - Final logits matrix:")
        print(base * temp)  # Multiply by temp to see raw logits
        
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(base, labels)
        
        print("\nLoss:", loss.item())
        
        return loss
    
    return 0

# Test the function
print("Input shapes:")
print("org shape:", org.shape)
print("sim shape:", sim.shape)
print("con shape:", con.shape)

loss = get_simcse_loss(org, sim, con)

# Let's also verify the actual similarity between org and sim embeddings
# for a single example
print("\nVerification of a single example:")
example_idx = 0
org_example = F.normalize(org[example_idx].mean(dim=0), p=2, dim=-1)
sim_example = F.normalize(sim[example_idx].mean(dim=0), p=2, dim=-1)
con_example = F.normalize(con[example_idx].mean(dim=0), p=2, dim=-1)

print("Similarity between org and sim:", F.cosine_similarity(org_example, sim_example, dim=0).item())
print("Similarity between org and con:", F.cosine_similarity(org_example, con_example, dim=0).item())

Input shapes:
org shape: torch.Size([3, 4, 8])
sim shape: torch.Size([3, 4, 8])
con shape: torch.Size([3, 4, 8])

Debug - Normalized vector norms:
org_pooled norms: tensor([1.0000, 1.0000, 1.0000])
sim_pooled norms: tensor([1.0000, 1.0000, 1.0000])
con_pooled norms: tensor([1.0000, 1.0000, 1.0000])

Debug - Similarity matrices (before temperature scaling):
Positive similarities (org vs sim):
tensor([[0.0491, 0.0142, 0.0320],
        [0.0117, 0.0498, 0.0141],
        [0.0283, 0.0100, 0.0499]])

Negative similarities (org vs con):
tensor([[ 0.0386,  0.0020, -0.0196],
        [ 0.0071,  0.0023, -0.0100],
        [ 0.0326,  0.0115, -0.0061]])

Debug - Final logits matrix:
tensor([[ 0.0491,  0.0142,  0.0320,  0.0886,  0.0020, -0.0196],
        [ 0.0117,  0.0498,  0.0141,  0.0071,  0.0523, -0.0100],
        [ 0.0283,  0.0100,  0.0499,  0.0326,  0.0115,  0.0439]])

Loss: 1.4543529748916626

Verification of a single example:
Similarity between org and sim: 0.9819483757019043
Similarity between