### Install Dependencies

In [None]:
%%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
# NOTE(review): %%capture silences this cell's output, so a failed pip step is
# not visible here; remove it temporarily when debugging the environment.
# Step 1: pinned CUDA 11.8 builds of torch / torchvision / torchaudio.
!pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
# Step 2: Unsloth straight from the GitHub main branch (unpinned).
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# Step 3: training stack installed without dependency resolution (--no-deps).
!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes
# Step 4: data / evaluation utilities (accelerate is listed again here, this
# time with dependency resolution).
!pip install datasets scipy ipywidgets accelerate evaluate bert_score scikit-learn
!pip install wandb -qU
# NOTE(review): this upgrades transformers / huggingface_hub to the latest
# release after the steps above; versions are not pinned.
!pip install --upgrade transformers huggingface_hub
!pip install matplotlib==3.6.0 plotly

### Set up Environment Variables

In [1]:
from google.colab import userdata
import gc

### Load Dataset

In [2]:
from google.colab import drive
drive.mount('/content/drive')

In [3]:
from datasets import load_dataset
training_data = load_dataset("json", data_files="/content/drive/MyDrive/Thesis/Datasets/perspectrum_instruction_dataset_v2.jsonl", split = "train")

Generating train split: 0 examples [00:00, ? examples/s]

In [5]:
import hashlib


def stable_entity_id(claim):
    """Return a reproducible signed 64-bit integer id for a claim string.

    The built-in ``hash()`` is salted per interpreter process for ``str``
    values, so ids derived from it differ between kernel restarts. This id is
    derived from the SHA-256 digest of the UTF-8 encoded claim instead, and is
    truncated to 8 bytes so it fits in a signed 64-bit integer.
    """
    digest = hashlib.sha256(claim.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


# Every row sharing the same claim text gets the same `entity_id`.
training_data = training_data.map(lambda x: {"entity_id": stable_entity_id(x['Claim'])})
training_data[-1]

Map:   0%|          | 0/4450 [00:00<?, ? examples/s]

{'Claim': 'Net Neutrality – All Internet Traffic Should Be Treated Equally.',
 'Context': 'Supporting Perspectives:\n- Net neutrality is required to preserve the existing structure of the internet.\n- Net neutrality provides for the free circulation of data and services.\n- Net neutrality maintains a free market and even playing field.\n- Free speech is a right that should be allowed online.\n- Net neutrality preserves free speech on the internet by prohibiting internet service providers from blocking content.\n- Blocking content violates everyones right to free speech.\n- Net neutrality protects free speech from internet service provider threats.\n- Net neutrality like in democracy , will develop and preserve democracy and free speech.\n- Net neutrality in a democratic society will provide a free and open internet.\n- Net neutrality helps preserve democracy and free speech.\n- Net neutrality protects consumers by preventing ISPs from speeding, slowing, or charging higher fees for sele

In [6]:
len(training_data)

4450

### Load Model and Tokenizer

In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
# Configuration for loading the base checkpoint through Unsloth.
max_seq_length = 2048 # Choose any! Mistral auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"

# Returns the model together with its tokenizer; both names are used as
# notebook-level globals by later cells.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

In [None]:
# Building the LoRA-enabled model.
# NOTE(review): this cell rebinds `model` to the value returned for the
# previous `model`, so re-running it passes the already-wrapped model back in.
# Restart from the loading cell before re-running.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

### Create Custom Dataset and collator

In [11]:
import random, torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

class trainDataset(Dataset):
    """Map-style dataset whose items are whole contrastive mini-batches.

    Indexing with ``idx`` returns a shuffled list of records built around the
    anchor ``data[idx]``:

    * the anchor itself,
    * up to two records with the same ``entity_id`` and the same ``Polarity``,
    * up to two records with the same ``entity_id`` and a different ``Polarity``,
    * one random record from each of several other entities, filling the list
      up to ``batch_size``.

    Args:
        data: indexable, iterable collection of dict-like records that carry
            at least the keys ``'entity_id'`` and ``'Polarity'``.
        batch_size: maximum number of records in each returned list.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, data, batch_size):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.data = data
        self.batch_size = batch_size
        # entity_id -> all records of that entity, built once up front.
        self.data_by_entity = defaultdict(list)
        for item in data:
            self.data_by_entity[item['entity_id']].append(item)

    def __len__(self):
        """Number of anchors, i.e. number of records in ``data``."""
        return len(self.data)

    @staticmethod
    def _sample_up_to_two(candidates):
        """Draw two distinct candidates, or fewer when fewer are available."""
        return random.sample(candidates, min(2, len(candidates)))

    def __getitem__(self, idx):
        """Build the contrastive batch (a list of records) for anchor ``idx``."""
        anchor = self.data[idx]
        batch = [anchor]
        same_entity = self.data_by_entity[anchor['entity_id']]

        # Same entity, same polarity (0-2 samples). Records equal to the anchor
        # are excluded by value, so exact duplicates of the anchor are skipped.
        same_entity_same_polarity = [item for item in same_entity
                                     if item['Polarity'] == anchor['Polarity'] and item != anchor]
        batch.extend(self._sample_up_to_two(same_entity_same_polarity))

        # Same entity, different polarity (0-2 samples).
        same_entity_diff_polarity = [item for item in same_entity
                                     if item['Polarity'] != anchor['Polarity']]
        batch.extend(self._sample_up_to_two(same_entity_diff_polarity))

        # Different entities fill the rest of the batch. The count is clamped
        # because random.sample raises ValueError for a negative count or one
        # larger than the population (few entities, or a small batch_size).
        other_entities = [entity for entity in self.data_by_entity.keys() if entity != anchor['entity_id']]
        fill_count = min(max(self.batch_size - len(batch), 0), len(other_entities))
        for rand_entity in random.sample(other_entities, fill_count):
            batch.append(random.choice(self.data_by_entity[rand_entity]))

        # Honour batch_size when the in-entity samples alone exceed it; the
        # anchor sits at index 0 and is therefore always kept.
        del batch[self.batch_size:]

        # The former per-item cleanup helper was dropped: `del` on its own loop
        # variable released nothing, and it forced a full gc.collect() for
        # every sample that was loaded.
        random.shuffle(batch)
        return batch

In [12]:
# Fixed preamble placed in front of every prompt (training and inference).
System_prompt = "Below is an instruction that describes an information requirement, paired with a claim that provides context. Write a response that appropriately addresses the instruction based on the given claim."
# Per-example template; filled via str.format with the keyword fields
# `instruction`, `claim` and `answer`.
Input_prompt = "### Instruction:\n{instruction}\n\n### Claim:\n{claim}\n\n### Response:\n{answer}"

In [13]:
def formatting_prompts_func(example):
    """Render one record as the full training text.

    Fills `Input_prompt` from the record's 'Instruction', 'Claim' and 'Answer'
    fields, prefixes `System_prompt` plus a blank line, and appends the
    tokenizer's EOS token.
    """
    filled_template = Input_prompt.format(
        instruction=example['Instruction'],
        claim=example['Claim'],
        answer=example['Answer'],
    )
    return System_prompt + "\n\n" + filled_template + tokenizer.eos_token

In [14]:
def collate_fn(batch):
    """Turn one pre-built contrastive batch into model-ready tensors.

    Only `batch[0]` is read: it is expected to be the list of records that
    makes up one contrastive batch. Each record is rendered to text, the texts
    are tokenized together (padded to the longest), and the 'entity_id',
    'Polarity' and 'sample_no' fields are stacked column-wise into 'labels'.
    """
    samples = batch[0]

    prompts = list(map(formatting_prompts_func, samples))
    encoded = tokenizer(prompts, padding='longest', return_tensors='pt')

    def column(key):
        # One 1-D tensor holding `key` for every record, in batch order.
        return torch.tensor([sample[key] for sample in samples])

    metadata_columns = [column('entity_id'), column('Polarity'), column('sample_no')]

    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': torch.stack(metadata_columns, dim=1),
    }

In [15]:
def prepare_dataloaders(training_data):
   """Create the training and evaluation DataLoaders over `training_data`.

   The training loader wraps `trainDataset` (contrastive batches of 8) and
   uses DataLoader batch_size=1 with `collate_fn`, so each step yields one
   pre-built batch. The evaluation loader iterates `training_data` directly
   in batches of 4 with the default collation.
   """
   shared_options = dict(shuffle=True, pin_memory=True, num_workers=4)
   contrastive_dataset = trainDataset(training_data, batch_size=8)
   train_dataloader = DataLoader(
      contrastive_dataset,
      batch_size=1,
      collate_fn=collate_fn,
      persistent_workers=True,
      **shared_options,
   )
   eval_data_loader = DataLoader(training_data, batch_size=4, **shared_options)
   return train_dataloader, eval_data_loader

### Set up Projection head and Custom Loss

In [17]:
class SimpleProjectionHead(torch.nn.Module):
    """Two-layer perceptron: Linear(input_dim, 256) -> ReLU -> Linear(256, projection_dim)."""

    def __init__(self, input_dim, projection_dim):
        super().__init__()
        hidden_width = 256
        self.fc1 = torch.nn.Linear(input_dim, hidden_width)
        self.fc2 = torch.nn.Linear(hidden_width, projection_dim)

    def forward(self, x):
        """Project `x` along its last dimension from input_dim to projection_dim."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

In [18]:
import torch.nn.functional as F
class HierarchicalContrastiveLoss(torch.nn.Module):
    """Contrastive loss with two levels of negatives (entity, then polarity).

    Every token embedding is treated as one sample. Pairs are classified from
    the per-sequence labels ``(entity_id, polarity)``:

    * positives: same entity and same polarity (a sample is always its own
      positive, because its labels equal themselves);
    * same-entity negatives: same entity, different polarity;
    * different-entity negatives: different entity.

    The loss is the sum of
    ``-log(sum_pos exp(s) / sum_all exp(s))`` averaged over samples, and two
    hinge terms ``relu(s - margin)`` averaged over each kind of negative pair,
    where ``s`` is the cosine similarity divided by ``temperature``.

    Args:
        temperature: divisor applied to the cosine similarities.
        entity_margin: hinge margin for different-entity pairs.
        polarity_margin: hinge margin for same-entity, different-polarity pairs.
    """

    def __init__(self, temperature=0.07, entity_margin=1.0, polarity_margin=0.5):
        super().__init__()
        self.temperature = temperature
        self.entity_margin = entity_margin
        self.polarity_margin = polarity_margin

    def forward(self, embeddings, labels):
        """Compute the loss.

        Args:
            embeddings: tensor of shape (batch, seq_len, emb_dim).
            labels: integer tensor of shape (batch, 2) holding
                (entity_id, polarity) for each sequence.

        Returns:
            Scalar tensor: positive term + both hinge terms.
        """
        _, seq_len, emb_dim = embeddings.shape
        assert labels.dim() == 2 and labels.shape[1] == 2, "Labels should be 2-dimensional with shape (batch_size, 2)"

        # Flatten to one row per token and compare all tokens with each other.
        token_embeddings = F.normalize(embeddings.reshape(-1, emb_dim), dim=1)
        # Cast to float32 so the log-sum-exp below is not evaluated in a
        # reduced-precision dtype when called under autocast.
        sim_matrix = (torch.matmul(token_embeddings, token_embeddings.T) / self.temperature).float()

        # Each token inherits the labels of the sequence it belongs to.
        token_labels = labels.unsqueeze(1).repeat(1, seq_len, 1).reshape(-1, 2)
        same_entity = token_labels[:, 0].unsqueeze(0) == token_labels[:, 0].unsqueeze(1)
        same_polarity = token_labels[:, 1].unsqueeze(0) == token_labels[:, 1].unsqueeze(1)

        pos_mask = same_entity & same_polarity
        same_entity_neg_mask = (same_entity & ~same_polarity).float()
        diff_entity_mask = (~same_entity).float()

        # The mask must select which exp(sim) terms enter the numerator.
        # Multiplying the similarities by the mask *before* exp() turned every
        # non-positive pair into exp(0) = 1 and added it to the numerator.
        # Masked entries are set to -inf so they contribute exp(-inf) = 0;
        # log-sum-exp keeps the computation numerically stable. Every row has
        # at least its diagonal entry as a positive, so no row is all -inf.
        log_pos = torch.logsumexp(sim_matrix.masked_fill(~pos_mask, float('-inf')), dim=1)
        log_all = torch.logsumexp(sim_matrix, dim=1)
        pos_loss = (log_all - log_pos).mean()

        # Hinge penalties; the 1e-8 guards against batches without such pairs.
        same_entity_neg_loss = (torch.relu(sim_matrix - self.polarity_margin) * same_entity_neg_mask).sum() / (same_entity_neg_mask.sum() + 1e-8)
        diff_entity_neg_loss = (torch.relu(sim_matrix - self.entity_margin) * diff_entity_mask).sum() / (diff_entity_mask.sum() + 1e-8)

        # No manual cleanup here: `del` inside a helper cannot release the
        # caller's tensors, and the intermediates are still needed by autograd
        # for the backward pass.
        return pos_loss + same_entity_neg_loss + diff_entity_neg_loss

In [19]:
class EarlyStopping:
    """Flags when training should stop because the monitored loss misbehaves.

    Call the instance once per batch with the current score. After the warm-up
    period, `early_stop` is set to True when either

    * the score rises by at least `margin` over the last accepted score, or
    * the score exceeds `threshold` for `consecutive_high` checks in a row.

    Args:
        threshold: score above which a check counts as "high".
        consecutive_high: number of consecutive high checks that trigger a stop.
        margin: minimum single-step increase that triggers an immediate stop.
        warmup_batch: checks with a batch number below this value are skipped.

    Attributes:
        last_score: last score seen by a check that did not trigger a stop.
        early_stop: True once a stop condition has been met.
    """

    def __init__(self, threshold=5, consecutive_high=3, margin=3, warmup_batch=50):
        self.threshold = threshold
        self.consecutive_high = consecutive_high
        self.high_count = 0
        self.margin = margin
        self.warmup_batch = warmup_batch
        self.last_score = float('inf')
        self.early_stop = False

    def __call__(self, score, batchnum):
        """Record `score` for batch `batchnum` and update `early_stop`."""
        if batchnum < self.warmup_batch:
            print("Skipped")
            return

        print("check")
        # Sudden jump: the score rose by `margin` or more since the last check.
        # `last_score` starts at +inf, so the first check can never trigger.
        if score - self.last_score >= self.margin:
            self.early_stop = True
            return

        # Sustained high score: count consecutive checks above the threshold.
        if score > self.threshold:
            self.high_count += 1
            if self.high_count >= self.consecutive_high:
                print("Max consecutives times threshold up")
                self.early_stop = True
                return
        else:
            self.high_count = 0

        # Only reached when no stop was triggered, so after a stop `last_score`
        # still holds the score from before the offending batch.
        self.last_score = score

### Set-up wandb logging

In [20]:
import wandb
class wandbStart():
    """Holds the Weights & Biases settings and starts a run on demand."""

    def __init__(self, login, project, name, notes, data):
        self.login, self.project = login, project
        self.name, self.notes = name, notes
        self.data = data

    def forward(self):
        """Log in with the stored key, start the run, and return the `wandb` module."""
        wandb.login(key=self.login)
        run_settings = {
            "project": self.project,
            "name": self.name,
            "tags": ["training", "adapter", "Contrastive", "Epoch 1"],
            "notes": self.notes,
            "config": self.data,
        }
        wandb.init(**run_settings)
        return wandb

### Check memory stats and print trainable parameters

In [21]:
import torch
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [22]:
# Report the device name, its total memory, and the peak memory reserved so
# far. Byte counts are divided by 1024**3 and rounded to 3 decimals.
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.394 GB.
4.668 GB of memory reserved.


In [23]:
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters require gradients, the total
    parameter count, and the trainable share as a percentage.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(size for size, _ in sizes)
    trainable_params = sum(size for size, needs_grad in sizes if needs_grad)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [24]:
print_trainable_parameters(model)

trainable params: 83886080 || all params: 3835957248 || trainable%: 2.186835633888691


### Set up Accelerator for Multi-GPU Sharding and Distribution

In [25]:
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

# FSDP plugin later handed to `Accelerator(fsdp_plugin=...)`.
# Both the model and the optimizer state-dict configs request CPU offload and
# rank0_only=False (presumably: every rank materializes the full state dict —
# confirm against the torch FSDP documentation).
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

### Set Up Contrastive Learning Loop

In [26]:
from transformers import TrainingArguments, get_cosine_schedule_with_warmup
from datetime import datetime
# Run identifier; reused for the output directory, the W&B run name and the
# Hub repository id.
model_id = "PerspectrumInstruct-Contrastive-InputSameAsLabels-Epochs_1-EarlyStop-Grad_Accum-Mistral7B"
# NOTE(review): these arguments are not passed to a `transformers.Trainer` in
# this notebook. The custom `trainModel` loop reads only `num_train_epochs`,
# `gradient_accumulation_steps`, `learning_rate`, `hub_model_id` and
# `hub_token`, and forwards the whole object to wandb as the run config. The
# remaining fields (save/eval strategy, optimizer, scheduler type, ...) are
# therefore recorded but do not appear to drive the training below.
training_arguments = TrainingArguments(
    output_dir= "./" + model_id,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    gradient_checkpointing=True,
    warmup_steps = 5,
    num_train_epochs = 1,
    learning_rate = 2e-4,
    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
    logging_steps = 20,
    logging_dir="./logs",
    save_strategy="steps", # Save checkpoints on a step schedule
    save_steps=20,         # Save a checkpoint every 20 steps
    eval_strategy="steps",
    eval_steps=20,         # Evaluate every 20 steps
    do_eval=True,          # Enable evaluation
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear", # NOTE(review): trainModel builds its own cosine schedule instead
    seed = 3407,
    report_to="wandb",    # Comment this out if you don't want to use Weights & Biases
    run_name=f"{model_id}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}", # Name of the W&B run (optional)
    push_to_hub=True,
    hub_model_id = f"TonyStarkD99/{model_id}",
    hub_strategy = "end",
    hub_token = userdata.get("HF_TOKEN"), # Hugging Face token read from Colab secrets
    greater_is_better=True, # NOTE(review): no metric_for_best_model is set; confirm "greater is better" is intended
    load_best_model_at_end=True
    )

In [27]:
def prepare_accelarator(model, optimizer, training_data):
   """Wrap the model, optimizer and both dataloaders with an Accelerator.

   When more than one CUDA device is visible the model is first flagged as
   parallelizable / model-parallel. Returns what `accelerator.prepare` returns
   for (model, optimizer, train_dataloader, eval_data_loader), in that order.
   """
   has_multiple_gpus = torch.cuda.device_count() > 1
   if has_multiple_gpus:
      model.is_parallelizable = model.model_parallel = True

   accelerator = Accelerator(fsdp_plugin=fsdp_plugin, mixed_precision='bf16')
   loaders = prepare_dataloaders(training_data)
   return accelerator.prepare(model, optimizer, *loaders)

### Do Training Loop

In [28]:
from tqdm.auto import tqdm
from torch.cuda.amp import autocast, GradScaler
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

class trainModel():
   """Custom training loop: causal-LM loss plus hierarchical contrastive loss.

   Each step feeds one pre-built contrastive batch to the model with
   `labels=input_ids`, projects the last hidden states through
   `SimpleProjectionHead`, and optimizes the sum of the model's own loss and
   `HierarchicalContrastiveLoss` on the projected states. The generation loss
   is monitored by `EarlyStopping`, and both losses are logged to wandb.

   Args:
      training_data: dataset handed to `prepare_accelarator` to build loaders.
      model: the (LoRA-wrapped) language model to train.
      tokenizer: tokenizer used for decoding previews and for `saveModel`.
      train_args: object providing `learning_rate`, `num_train_epochs`,
         `gradient_accumulation_steps`, `hub_model_id` and `hub_token`.
      device: device the batches and the projection head are moved to.
      max_seq_length: stored on the instance; not used by the loop itself.
      warmup_ratio: fraction of optimizer steps used for scheduler warm-up.
   """

   def __init__(self, training_data, model, tokenizer, train_args, device, max_seq_length, warmup_ratio):
      self.training_data=training_data
      self.model=model
      self.tokenizer=tokenizer
      self.train_args=train_args
      self.device=device
      self.max_seq_length=max_seq_length
      self.warmup_ratio = warmup_ratio
      self.early_stopping = EarlyStopping()
      # NOTE(review): the loop uses autocast() with its default dtype together
      # with this GradScaler, while the Accelerator is created with
      # mixed_precision='bf16' — confirm which precision is intended.
      self.scaler=GradScaler()
      self.wandb=wandbStart(userdata.get("WANDB_TOKEN"), #Your wandb access token here
                            "PerspectrumInstruct-Contrastive-FT-Unsloth_Mistral7B",
                            "InputSameAsLabels-CLType_SelfMask-OptimChg_Yes-Epochs_1-EarlyStop-Gradient_Accumulation",
                            "Fine-tuning Mitsral 7B on Perspectrum Instruction dataset with contrastive learning InputSameAsLabels-Epoch_1-EarlyStop-Gradient_Accumulation-Mistral7B",
                            self.train_args).forward()
      self.tsne = TSNE(n_components=2, random_state=42)
      self.projection_head = SimpleProjectionHead(self.model.config.hidden_size, 128).to(self.device)
      self.contrastive_loss_fn = HierarchicalContrastiveLoss().to(self.device)
      # One optimizer covers both the model and the projection head.
      self.optimizer = torch.optim.AdamW(list(self.model.parameters()) + list(self.projection_head.parameters()), lr=self.train_args.learning_rate, betas=(0.9,0.99), eps=1e-5)
      self.model, self.optimizer, self.train_dataloader, self.eval_data_loader = prepare_accelarator(self.model, self.optimizer, self.training_data)
      # The scheduler advances once per optimizer step, i.e. once every
      # `gradient_accumulation_steps` batches.
      optimizer_steps = len(self.train_dataloader)//self.train_args.gradient_accumulation_steps
      self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                       num_warmup_steps=int(optimizer_steps * self.warmup_ratio),
                                                       num_training_steps=optimizer_steps)

   def __print_memory_stats(self):
      """Print currently allocated and reserved CUDA memory in MB."""
      print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
      print(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

   def __clear_data(self, *_unused):
      """Release cached CUDA blocks and run the garbage collector.

      Positional arguments are accepted for backward compatibility and
      ignored: a callee cannot drop the caller's references, so callers must
      `del` their own names before calling this.
      """
      try:
         torch.cuda.empty_cache()
         gc.collect()
      except Exception as e:
         print("C", f"Error deleting because of {e}")

   def __for_train(self):
      """Switch the projection head and the model to training mode."""
      self.projection_head.train()
      FastLanguageModel.for_training(self.model)

   def __for_eval(self):
      """Switch the projection head and the model to inference mode."""
      self.projection_head.eval()
      FastLanguageModel.for_inference(self.model)

   def forward(self):
      """Run the training epochs; stops early when `EarlyStopping` fires."""
      self.__for_train()
      self.wandb.watch(self.model)
      accumulation_steps = self.train_args.gradient_accumulation_steps
      for epoch in tqdm(range(self.train_args.num_train_epochs)):
         total_loss = 0
         batches_seen = 0
         for count, batch in enumerate(tqdm(self.train_dataloader)):
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            other_labels = batch["labels"].to(self.device)

            # Only the forward pass and the loss computation run under autocast.
            with autocast():
               # NOTE(review): labels=input_ids also scores padded positions;
               # presumably those should be masked out of the LM loss — confirm.
               outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, output_hidden_states=True)

               batch_projected_embeddings = self.projection_head(outputs.hidden_states[-1])

               # Only the first two label columns (entity id, polarity) are used.
               contrastive_loss = self.contrastive_loss_fn(batch_projected_embeddings, other_labels[:, :2])

               sum_loss = outputs.loss + contrastive_loss

            generation_loss = outputs.loss.item()
            self.early_stopping(generation_loss, count+1)
            self.wandb.log({"Generation_loss": generation_loss,
                            "Contrastive_loss": contrastive_loss.item()})
            if self.early_stopping.early_stop:
               print(f"Early stopped due to loss shooting from {self.early_stopping.last_score} to {generation_loss}")
               break

            # Backward runs outside the autocast region. Gradients of
            # `accumulation_steps` consecutive batches are summed (the loss is
            # not divided by the number of accumulation steps).
            self.scaler.scale(sum_loss).backward()

            total_loss += sum_loss.item()
            batches_seen += 1
            if (count+1)%accumulation_steps==0:
               self.scaler.step(self.optimizer)
               self.scaler.update()
               self.scheduler.step()
               self.optimizer.zero_grad(set_to_none=True)

            if (count+1)%50 == 0:
               # Preview of the greedy (argmax) prediction for the first sample.
               print(self.tokenizer.batch_decode(outputs.logits.argmax(dim=-1))[0])

            # Drop this step's references here, in the owning scope, so the
            # cache release below can actually return their memory.
            del input_ids, attention_mask, other_labels
            del outputs, batch_projected_embeddings, sum_loss, contrastive_loss
            self.__clear_data()

         # Average over the batches that contributed to `total_loss`; after an
         # early stop that is fewer than len(self.train_dataloader).
         avg_loss = total_loss / max(batches_seen, 1)

         self.wandb.log({f"Total Loss for epoch {epoch+1}": total_loss,
                         f"Average Loss for epoch {epoch+1}": avg_loss})

         self.__clear_data()

         # An early stop ends the whole run, not just the current epoch.
         if self.early_stopping.early_stop:
            break

   def saveModel(self):
      """Push the model and the tokenizer to the configured Hub repository."""
      # if self.train_args.push_to_hub:
      self.model.push_to_hub(repo_id=self.train_args.hub_model_id, token=self.train_args.hub_token)
      self.tokenizer.push_to_hub(repo_id=self.train_args.hub_model_id, token=self.train_args.hub_token)

### Run and Save Model

In [None]:
# Fraction of optimizer steps used as scheduler warm-up.
warmup_rat = 0.01
# Constructing the trainer logs in to wandb, starts the run, and builds the
# optimizer, dataloaders and scheduler; `forward()` then runs the training loop.
trainer = trainModel(training_data, model, tokenizer, training_arguments, device, max_seq_length, warmup_rat)
trainer.forward()

In [None]:
trainer.saveModel()

### Evaluate Model after Loading from Hugging Face

In [None]:
table = wandb.Table(columns=["instruction", "claim", "response"])

In [None]:
FastLanguageModel.for_inference(model)

In [None]:
inst = 'Compare the opinion distributions for the following claim.'
clm = 'Vaccination must be made compulsory.'

In [None]:
# System_prompt and Input_prompt defined above
# The answer slot is left empty so the prompt ends right after
# "### Response:\n" and the model continues from there.
text = Input_prompt.format(instruction=inst,
                           claim=clm,
                           answer='',)
full_prompt = f"{System_prompt}\n\n{text}"
print(full_prompt)

# Tokenize the single prompt as a batch of one and move it to the GPU.
inputs = tokenizer(
[
    full_prompt,
], return_tensors = "pt").to(torch.device("cuda"))

In [None]:
inputs

In [None]:
from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
output = model.generate(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask,
                   streamer = text_streamer, max_new_tokens = 128, pad_token_id = tokenizer.eos_token_id)

In [None]:
# `output[0]` holds the prompt tokens followed by the generated tokens. Keep
# only the newly generated part and drop special tokens, so the "response"
# column of the wandb table does not repeat the prompt.
prompt_length = inputs.input_ids.shape[1]
gen_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

In [None]:
table.add_data(inst, clm, gen_text)

In [None]:
wandb.log({"Generations":table})

In [None]:
wandb.finish()