In [6]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import gzip
import json
import pandas as pd
# Load pretrained model
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')




In [4]:
class MultiTaskModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size

        # Regression head
        self.regression_head = nn.Linear(hidden_size, 1)

        # Classification head (assuming rating from 0 to 4)
        self.classification_head = nn.Linear(hidden_size, 5)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]

        regression_output = self.regression_head(pooled_output).squeeze(-1)
        classification_output = self.classification_head(pooled_output)

        return regression_output, classification_output

class RedTeamDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        encoding = tokenizer(
            row['transcript'],
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'harmlessness_score': torch.tensor(row['min_harmlessness_score_transcript'], dtype=torch.float),
            'rating': torch.tensor(int(row['rating']), dtype=torch.long)
        }


In [8]:
data = []
file_path = "data/red-team-attempts/red_team_attempts.jsonl.gz"
with gzip.open(file_path, 'rt', encoding='utf-8') as f:
    # skip first symbol '['
    f.read(1)
    
    buffer = ""
    for line in f:
        line = line.strip()
        if line == "]":  
            break
        if line.endswith(','):
            line = line[:-1] 
        
        buffer += line
        try:
            record = json.loads(buffer)
            data.append(record)
            buffer = ""  
        except json.JSONDecodeError:
        
            continue

df = pd.DataFrame(data)

In [12]:
# Instantiate model
model = MultiTaskModel('distilbert-base-uncased')

# DataLoader
dataset = RedTeamDataset(df)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Loss functions
criterion_regression = nn.MSELoss()
criterion_classification = nn.CrossEntropyLoss()

# Training Loop
model.train()


MultiTaskModel(
  (encoder): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Line

In [14]:
from tqdm import tqdm

model.train()

for epoch in range(3):  # Number of epochs
    total_loss = 0

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}", unit="batch")

    for batch in progress_bar:
        optimizer.zero_grad()

        regression_output, classification_output = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        loss_regression = criterion_regression(regression_output, batch['harmlessness_score'])
        loss_classification = criterion_classification(classification_output, batch['rating'])

        # Total loss (weighted sum)
        loss = loss_regression * 0.5 + loss_classification * 0.5

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # update loss in progress bar
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(loader)
    print(f'Epoch {epoch+1} Average Loss: {avg_loss:.4f}')

Epoch 1:   0%|          | 4/2435 [00:38<6:30:50,  9.65s/batch, loss=1.62]


KeyboardInterrupt: 