<a href="https://colab.research.google.com/github/MJMortensonWarwick/ADA/blob/main/Untitled34.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets

import torch
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

# 1. Load Dataset and Tokenizer
# IMDB movie reviews (a DatasetDict); its "train" split is sampled for fine-tuning later on.
dataset = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    """Tokenize one batch of reviews for BERT.

    ``examples`` is the batch mapping handed over by ``Dataset.map(batched=True)``;
    only its "text" column is read. Every sequence is truncated if too long and
    padded out to the tokenizer's maximum length, so all rows share one length.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding="max_length")


tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 2. Create Embeddings-Only Model with Custom Head

class CustomHeadModel(nn.Module):
    """Two-layer MLP classification head applied to a fixed-size embedding.

    Computes ``linear2(relu(linear1(embeddings)))``. In this notebook it is fed
    BERT's ``pooler_output``.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        num_labels: Number of output classes (size of the returned logits).
        hidden_dim: Width of the hidden layer. Defaults to 64.
    """

    def __init__(self, embedding_dim, num_labels, hidden_dim=64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, num_labels)

    def forward(self, embeddings):
        """Return unnormalized class logits of shape ``(..., num_labels)``.

        ``nn.Linear`` acts on the last dimension, so any leading batch
        dimensions of ``embeddings`` are preserved.
        """
        x = self.relu(self.linear1(embeddings))
        return self.linear2(x)


# Pretrained BERT encoder; its hidden size sets the input width of the head.
model_embeddings = AutoModel.from_pretrained("bert-base-uncased")
embedding_dim = model_embeddings.config.hidden_size
num_labels = 2  # binary classification (positive/negative reviews)
custom_head = CustomHeadModel(embedding_dim, num_labels)

# 3. Embeddings-only prediction (no fine-tuning)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

for module in (model_embeddings, custom_head):
    module.to(device)

# Example inference (using only embeddings).
# NOTE: custom_head has not been trained at this point; its weights are freshly
# initialized, so the printed label is not yet a meaningful sentiment prediction.
example_text = "This is a great movie!"
encoded_input = tokenizer(example_text, return_tensors="pt").to(device)
with torch.no_grad():
    encoder_output = model_embeddings(**encoded_input)
    embeddings = encoder_output.pooler_output
    logits = custom_head(embeddings)
    predictions = logits.argmax(dim=1)
    print(f"Prediction for '{example_text}': {predictions.item()}")  # 0 or 1

# 4. Fine-tuning the Model with the Custom Head

# Bundle encoder and head so a single optimizer updates both.
# NOTE: this Sequential is used as a parameter container. Calling it end-to-end
# would hand the encoder's whole output object to the head, but the head needs
# the .pooler_output tensor. That is why model_fine_tune[0] / model_fine_tune[1]
# are invoked explicitly wherever the model is run.
model_fine_tune = nn.Sequential(model_embeddings, custom_head)
model_fine_tune.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model_fine_tune.parameters(), lr=5e-5)

# Small, reproducibly shuffled slice of the training split, for demonstration only.
small_train_dataset = (
    tokenized_datasets["train"]
    .shuffle(seed=42)
    .select(range(100))
)

# Custom collate function: the tokenized dataset yields plain Python lists.
def custom_collate(batch):
    """Turn a list of tokenized examples into a dict of batched tensors.

    Each example must provide list-valued ``input_ids`` and ``attention_mask``
    (the same length across the batch) plus a scalar ``label``; any other keys,
    such as ``text``, are ignored.

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'`` (one row per example)
        and ``'label'`` (one entry per example).
    """
    def column(key):
        return torch.tensor([example[key] for example in batch])

    return {key: column(key) for key in ('input_ids', 'attention_mask', 'label')}

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=16, collate_fn=custom_collate)  # use custom collate

num_epochs = 3  # train for 3 epochs
model_fine_tune.train()  # put the encoder and head in training mode
for epoch in range(num_epochs):
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # custom_collate already returns a tensor here; re-wrapping it with
        # torch.tensor() makes a redundant copy and emits a PyTorch UserWarning.
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        # The Sequential cannot chain these modules directly (the head needs the
        # encoder's .pooler_output tensor, not its full output object), so run
        # the encoder (index 0) and then the head (index 1) explicitly.
        pooled = model_fine_tune[0](input_ids=input_ids, attention_mask=attention_mask).pooler_output
        outputs = model_fine_tune[1](pooled)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} complete")

# Switch to evaluation mode before predicting. The training loop left the model
# in train() mode, which keeps any dropout layers in the encoder active and would
# make this prediction non-deterministic.
model_fine_tune.eval()

example_text = "This is a terrible movie!"
encoded_input = tokenizer(example_text, return_tensors="pt").to(device)
with torch.no_grad():
    # Run the encoder (index 0) on the tokenized input passed as keyword
    # arguments, then feed its pooled output to the classification head (index 1).
    encoder, head = model_fine_tune[0], model_fine_tune[1]
    outputs = head(encoder(**encoded_input).pooler_output)
    predictions = torch.argmax(outputs, dim=1)
    print(f"Prediction for '{example_text}': {predictions.item()}")
