In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaModel, CLIPProcessor, CLIPModel
from transformers import WhisperProcessor, WhisperModel
from sklearn.metrics import classification_report
import numpy as np
import torchaudio
import cv2
from PIL import Image
import random

In [None]:
# ------------------ Mount Google Drive ------------------
# Colab-only step: exposes Google Drive under /content/drive so that the Drive
# paths used later in this notebook (checkpoints, audio, transcripts) resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ------------------ Set Checkpoint Directory ------------------
# Folder on the mounted Drive that receives the per-epoch checkpoints.
# exist_ok=True makes re-running this cell harmless.
CHECKPOINT_DIR = "/content/drive/MyDrive/MultimodalModelCheckpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# ------------------ Set Seed ------------------
def set_seed(seed=42):
    """Seed every RNG used in this notebook and pin cuDNN to deterministic mode.

    Args:
        seed: Integer applied to Python's ``random``, NumPy and the PyTorch
            CPU/CUDA generators.
    """
    # Python and NumPy generators first, then the PyTorch ones.
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Turn off cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)

In [None]:
# ------------------ Preprocessing Helpers ------------------
# Loading Whisper from disk is slow, so one (processor, model) pair is kept per device.
_WHISPER_CACHE = {}
_WHISPER_SAMPLE_RATE = 16000   # sampling rate passed to the Whisper processor
_WHISPER_NUM_FRAMES = 3000     # fixed mel-frame length fed to the encoder


def _get_whisper(device):
    """Return a cached ``(processor, model)`` pair for whisper-base placed on ``device``."""
    key = str(device)
    if key not in _WHISPER_CACHE:
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
        model = WhisperModel.from_pretrained("openai/whisper-base").to(device).eval()
        _WHISPER_CACHE[key] = (processor, model)
    return _WHISPER_CACHE[key]


def extract_audio_embedding(audio_path, device='cpu', duration=30):
    """Embed an audio file as the time-averaged Whisper encoder output.

    The clip is down-mixed to mono, cut to its first ``duration`` seconds,
    resampled to 16 kHz, converted to mel features, and run through the
    whisper-base encoder.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        device: Device on which the Whisper encoder runs.
        duration: Number of leading seconds of audio to keep.

    Returns:
        Tensor of shape ``[1, hidden]`` on ``device`` (the mean over encoder
        time steps; 512 for whisper-base according to the original comments).
    """
    processor, model = _get_whisper(device)

    waveform, sr = torchaudio.load(audio_path)

    # Down-mix multi-channel audio by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Trim at the original sampling rate, then resample if needed.
    waveform = waveform[:, :int(duration * sr)]
    if sr != _WHISPER_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=_WHISPER_SAMPLE_RATE
        )

    # squeeze(0) drops only the channel axis; a bare squeeze() would also
    # collapse a one-sample clip to a 0-d array.
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=_WHISPER_SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    input_features = inputs['input_features']

    # Guarantee a leading batch axis: [batch, n_mels, frames].
    if input_features.ndimension() == 2:
        input_features = input_features.unsqueeze(0)

    # Pad / truncate the frame axis to the fixed length, mirroring
    # TokenAudioDataset.extract_audio_embedding, so clips shorter than
    # `duration` do not reach the encoder with a short feature sequence.
    n_frames = input_features.shape[-1]
    if n_frames < _WHISPER_NUM_FRAMES:
        input_features = torch.nn.functional.pad(
            input_features, (0, _WHISPER_NUM_FRAMES - n_frames), mode='constant', value=0
        )
    elif n_frames > _WHISPER_NUM_FRAMES:
        input_features = input_features[..., :_WHISPER_NUM_FRAMES]

    # Only the encoder half of Whisper is needed for embeddings.
    encoder = model.get_encoder()
    with torch.no_grad():
        encoder_outputs = encoder(input_features=input_features.to(device))

    # Average over all time steps to get one vector for the whole clip.
    return encoder_outputs.last_hidden_state.mean(dim=1)

# def extract_frame(video_path, timestamp_str):
#     print(video_path)
#     time_in_seconds = float(timestamp_str)
#     cap = cv2.VideoCapture(video_path)
#     fps = cap.get(cv2.CAP_PROP_FPS)

#     if not cap.isOpened() or fps == 0:
#         raise ValueError("Error opening video file or unable to retrieve FPS.")

#     frame_number = int(fps * time_in_seconds)
#     cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
#     success, frame = cap.read()
#     cap.release()

#     if not success or frame is None:
#         raise ValueError(f"Could not extract frame at {timestamp_str} seconds")

#     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#     return Image.fromarray(frame)

# def get_image_embedding(image, clip_processor, clip_model, device='cpu'):
#     inputs = clip_processor(images=image, return_tensors="pt").to(device)
#     with torch.no_grad():
#         image_embed = clip_model.get_image_features(**inputs)
#     return image_embed.squeeze(0)  # [H_img]

In [None]:
from torch.utils.data import Dataset
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperModel

class TokenAudioDataset(Dataset):
    """Token-classification dataset pairing word lists with a clip-level audio embedding.

    Each item holds the tokenizer output for one transcript (padded/truncated
    to ``max_len``), one Whisper embedding of the matching audio file repeated
    along the sequence axis, and the per-token labels.

    NOTE(review): the tokenizer is called with its default special-token
    handling and may split an input word into several sub-tokens, while
    ``labels`` is used as-is (one entry per input word, no offset). Label
    positions therefore may not line up with ``input_ids`` - confirm.

    Args:
        texts: Sequence of word lists (one list per sample).
        audio_paths: Sequence of audio file paths, parallel to ``texts``.
        labels: Sequence of integer label lists, parallel to ``texts``.
        tokenizer: Hugging Face tokenizer called with ``is_split_into_words=True``.
        max_len: Fixed sequence length for ``input_ids`` / ``audio_embedding``.
        device: Device on which the Whisper encoder runs.
        whisper_processor: Optional pre-loaded ``WhisperProcessor``; pass the
            same object to several datasets to avoid loading it repeatedly.
        whisper_model: Optional pre-loaded ``WhisperModel`` (moved to
            ``device`` and put in eval mode here).
    """

    WHISPER_NAME = "openai/whisper-base"
    SAMPLE_RATE = 16000      # sampling rate passed to the Whisper processor
    NUM_MEL_FRAMES = 3000    # fixed mel-frame length fed to the encoder

    def __init__(self, texts, audio_paths, labels, tokenizer, max_len, device,
                 whisper_processor=None, whisper_model=None):
        self.texts = texts
        self.audio_paths = audio_paths
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.device = device

        # Whisper model and processor: reuse the ones handed in, else load them.
        if whisper_processor is None:
            whisper_processor = WhisperProcessor.from_pretrained(self.WHISPER_NAME)
        if whisper_model is None:
            whisper_model = WhisperModel.from_pretrained(self.WHISPER_NAME)
        self.whisper_processor = whisper_processor
        self.whisper_model = whisper_model.to(device).eval()

        # audio_path -> CPU embedding. The encoder runs under no_grad, so the
        # embedding for a file does not change between epochs.
        self._audio_cache = {}

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (``[max_len]``),
            ``audio_embedding`` (``[max_len, H_a]``, on CPU) and ``labels``
            (1-D long tensor of at most ``max_len`` entries, not padded).
        """
        text = self.texts[idx]
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]

        # Tokenize text
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
            is_split_into_words=True,
            return_attention_mask=True
        )
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Audio embedding: computed once per file, then served from the cache.
        # Kept on CPU so items are plain CPU tensors; the training code moves
        # each batch to the target device itself.
        audio_vec = self._audio_cache.get(audio_path)
        if audio_vec is None:
            audio_vec = self.extract_audio_embedding(audio_path).detach().cpu()
            self._audio_cache[audio_path] = audio_vec
        audio_embedding = audio_vec.unsqueeze(0).repeat(self.max_len, 1)  # [T, H_a]

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'audio_embedding': audio_embedding,
            'labels': torch.tensor(label[:self.max_len], dtype=torch.long)
        }

    def extract_audio_embedding(self, audio_path, duration=30):
        """Run the Whisper encoder on ``audio_path`` and mean-pool over time.

        Args:
            audio_path: Path to an audio file readable by ``torchaudio.load``.
            duration: Number of leading seconds of audio to keep.

        Returns:
            1-D tensor on ``self.device`` (the original comment gives its
            size as 512 for whisper-base).
        """
        waveform, sr = torchaudio.load(audio_path)

        # Down-mix multi-channel audio by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Trim at the original sampling rate, then resample if needed.
        waveform = waveform[:, :int(duration * sr)]
        if sr != self.SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=self.SAMPLE_RATE
            )

        # squeeze(0) drops only the channel axis; a bare squeeze() would also
        # collapse a one-sample clip to a 0-d array.
        inputs = self.whisper_processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        )
        input_features = inputs['input_features']

        # Guarantee a leading batch axis before touching the frame axis.
        if input_features.ndimension() == 2:
            input_features = input_features.unsqueeze(0)

        # Pad or truncate to a consistent number of frames.
        n_frames = input_features.shape[-1]
        if n_frames < self.NUM_MEL_FRAMES:
            input_features = torch.nn.functional.pad(
                input_features, (0, self.NUM_MEL_FRAMES - n_frames), mode='constant', value=0
            )
        elif n_frames > self.NUM_MEL_FRAMES:
            input_features = input_features[..., :self.NUM_MEL_FRAMES]

        encoder = self.whisper_model.get_encoder()
        with torch.no_grad():
            encoder_outputs = encoder(input_features=input_features.to(self.device))

        return encoder_outputs.last_hidden_state.mean(dim=1).squeeze(0)


In [None]:
# from torch.utils.data import Dataset
# import torch
# import torchaudio
# from transformers import WhisperProcessor, WhisperModel
# import os

# class TokenAudioVisualDataset(Dataset):
#     def __init__(self, texts, timestamps, audio_paths, video_paths, labels, tokenizer, clip_processor, clip_model, max_len, device):
#         self.texts = texts
#         self.timestamps = timestamps
#         self.audio_paths = audio_paths
#         self.video_paths = video_paths
#         self.labels = labels
#         self.tokenizer = tokenizer
#         self.max_len = max_len
#         self.clip_processor = clip_processor
#         self.clip_model = clip_model
#         self.device = device

#         # Whisper model and processor (shared for efficiency)
#         self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
#         self.whisper_model = WhisperModel.from_pretrained("openai/whisper-base").to(device).eval()

#     def __len__(self):
#         return len(self.texts)

#     def __getitem__(self, idx):
#         text = self.texts[idx]
#         timestamps = self.timestamps[idx]
#         audio_path = self.audio_paths[idx]
#         video_path = self.video_paths[idx]
#         label = self.labels[idx]

#         # Tokenize text
#         encoded = self.tokenizer(
#             text,
#             truncation=True,
#             padding='max_length',
#             max_length=self.max_len,
#             return_tensors='pt',
#             is_split_into_words=True,
#             return_attention_mask=True
#         )
#         input_ids = encoded['input_ids'].squeeze(0)
#         attention_mask = encoded['attention_mask'].squeeze(0)

#         # Audio Embedding (lazy loading)
#         audio_embedding = self.extract_audio_embedding(audio_path)
#         audio_embedding = audio_embedding.unsqueeze(0).repeat(self.max_len, 1)  # [T, H_a]

#         # Video Embedding per timestamp
#         image_embeds = []
#         for ts in timestamps[:self.max_len]:
#             frame = extract_frame(video_path, ts)
#             img_embed = get_image_embedding(frame, self.clip_processor, self.clip_model, device=self.device)
#             image_embeds.append(img_embed)

#         # Pad if fewer frames than max_len
#         while len(image_embeds) < self.max_len:
#             image_embeds.append(torch.zeros_like(image_embeds[0]))
#         image_embeds = torch.stack(image_embeds, dim=0)  # [T, H_img]

#         return {
#             'input_ids': input_ids,
#             'attention_mask': attention_mask,
#             'audio_embedding': audio_embedding,
#             'image_embedding': image_embeds,
#             'labels': torch.tensor(label[:self.max_len], dtype=torch.long)
#         }

#     def extract_audio_embedding(self, audio_path, duration=30):
#         waveform, sr = torchaudio.load(audio_path)

#         if waveform.shape[0] > 1:
#             waveform = waveform.mean(dim=0, keepdim=True)

#         # Trim and resample
#         num_samples = int(duration * sr)
#         waveform = waveform[:, :num_samples]

#         if sr != 16000:
#             waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

#         inputs = self.whisper_processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
#         input_features = inputs['input_features']

#         input_features = inputs['input_features']
#         if input_features.shape[-1] < 3000:
#             pad_len = 3000 - input_features.shape[-1]
#             input_features = torch.nn.functional.pad(input_features, (0, pad_len), mode='constant', value=0)
#         elif input_features.shape[-1] > 3000:
#             input_features = input_features[:, :, :3000]  # truncate

#         if input_features.ndimension() == 2:
#             input_features = input_features.unsqueeze(0)

#         encoder = self.whisper_model.get_encoder()
#         with torch.no_grad():
#             encoder_outputs = encoder(input_features=input_features.to(self.device))
#         return encoder_outputs.last_hidden_state.mean(dim=1).squeeze(0)  # [512]

In [None]:
# # ------------------ Dataset ------------------
# class TokenAudioVisualDataset(Dataset):
#     def __init__(self, texts, timestamps, audio_path, video_path, labels, tokenizer, clip_processor, clip_model, max_len, device):
#         self.texts = texts
#         self.timestamps = timestamps
#         self.labels = labels
#         self.video_path = video_path
#         self.tokenizer = tokenizer
#         self.max_len = max_len
#         self.clip_processor = clip_processor
#         self.clip_model = clip_model
#         self.device = device

#         audio_embedding_seq = extract_audio_embedding(audio_path, device=device)
#         self.audio_embedding = audio_embedding_seq.mean(dim=0)

#     def __len__(self):
#         return len(self.texts)

#     def __getitem__(self, idx):
#         text = self.texts[idx]
#         timestamps = self.timestamps[idx]
#         label = self.labels[idx]

#         encoded = self.tokenizer(
#             text,
#             truncation=True,
#             padding='max_length',
#             max_length=self.max_len,
#             return_tensors='pt',
#             is_split_into_words=True,
#             return_attention_mask=True
#         )

#         input_ids = encoded['input_ids'].squeeze(0)
#         attention_mask = encoded['attention_mask'].squeeze(0)

#         audio_embed = self.audio_embedding
#         if audio_embed.dim() == 2:
#             audio_embed = audio_embed.mean(dim=0)
#         audio_embed = audio_embed.unsqueeze(0).repeat(self.max_len, 1)  # [T, H_a]

#         image_embeds = []
#         for ts in timestamps[:self.max_len]:
#             frame = extract_frame(self.video_path, ts)
#             img_embed = get_image_embedding(frame, self.clip_processor, self.clip_model, device=self.device)
#             image_embeds.append(img_embed)

#         while len(image_embeds) < self.max_len:
#             image_embeds.append(torch.zeros_like(image_embeds[0]))

#         image_embeds = torch.stack(image_embeds, dim=0)  # [T, H_img]

#         return {
#             'input_ids': input_ids,
#             'attention_mask': attention_mask,
#             'audio_embedding': audio_embed,
#             'image_embedding': image_embeds,
#             'labels': torch.tensor(label[:self.max_len], dtype=torch.long)
#         }

In [None]:
# ------------------ Model ------------------
# class RobertaAudioVisualClassifier(nn.Module):
#     def __init__(self, audio_dim=512, image_dim=512, hidden_size=256, num_labels=3):
#         super().__init__()
#         self.roberta = RobertaModel.from_pretrained("roberta-base")
#         self.classifier = nn.Sequential(
#             nn.Linear(self.roberta.config.hidden_size + audio_dim + image_dim, hidden_size),
#             nn.ReLU(),
#             nn.Dropout(0.1),
#             nn.Linear(hidden_size, num_labels)
#         )

#     def forward(self, input_ids, attention_mask, audio_embedding, image_embedding):
#         outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
#         roberta_out = outputs.last_hidden_state  # [B, T, H_text]
#         fused = torch.cat([roberta_out, audio_embedding, image_embedding], dim=-1)
#         logits = self.classifier(fused)
#         return logits

import torch
import torch.nn as nn
from transformers import RobertaModel

# ------------------ Model ------------------
class RobertaAudioClassifier(nn.Module):
    """Per-token classifier over RoBERTa hidden states concatenated with audio features.

    Args:
        audio_dim: Size of the last axis of ``audio_embedding``.
        hidden_size: Width of the hidden layer in the classification head.
        num_labels: Number of output classes per token.
    """

    def __init__(self, audio_dim=512, hidden_size=256, num_labels=3):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")

        # The head consumes text and audio features side by side.
        fused_dim = self.roberta.config.hidden_size + audio_dim
        head_layers = [
            nn.Linear(fused_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask, audio_embedding):
        """Return logits of shape ``[B, T, num_labels]``.

        ``audio_embedding`` is concatenated to the RoBERTa output along the
        last axis, so its leading axes must match ``[B, T]``.
        """
        token_states = self.roberta(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # [B, T, H_text]
        text_and_audio = torch.cat((token_states, audio_embedding), dim=-1)  # [B, T, H_text + H_audio]
        return self.classifier(text_and_audio)


In [None]:
# from tqdm import tqdm
# import torch
# import os
# from sklearn.metrics import classification_report

# # ------------------ Train & Eval ------------------
# def train(model, dataloader, optimizer, loss_fn, device, epoch=None, save_path=None):
#     model.train()
#     total_loss = 0
#     progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}" if epoch is not None else "Training")

#     for batch in progress_bar:
#         input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         audio_embedding = batch['audio_embedding'].to(device)
#         image_embedding = batch['image_embedding'].to(device)
#         labels = batch['labels'].to(device)

#         optimizer.zero_grad()
#         logits = model(input_ids, attention_mask, audio_embedding, image_embedding)

#         labels = labels[:len(logits)]
#         padding_len = logits.size(1) - labels.size(1)
#         if padding_len > 0:
#             pad = torch.full((1, padding_len), -100, dtype=labels.dtype, device=labels.device)
#             labels = torch.cat([labels, pad], dim=1)
#         else:
#             labels = labels[:, :logits.size(1)]

#         loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#         progress_bar.set_postfix(loss=loss.item())

#     avg_loss = total_loss / len(dataloader)

#     if save_path and epoch is not None:
#         ckpt_path = os.path.join(save_path, f"model_epoch_{epoch}.pt")
#         torch.save(model.state_dict(), ckpt_path)
#         print(f"✅ Saved model at: {ckpt_path}")

#     return avg_loss

# def evaluate(model, dataloader, device):
#     model.eval()
#     all_preds, all_labels = [], []
#     progress_bar = tqdm(dataloader, desc="Evaluating")

#     with torch.no_grad():
#         for batch in progress_bar:
#             input_ids = batch['input_ids'].to(device)
#             attention_mask = batch['attention_mask'].to(device)
#             audio_embedding = batch['audio_embedding'].to(device)
#             image_embedding = batch['image_embedding'].to(device)
#             labels = batch['labels'].to(device)

#             logits = model(input_ids, attention_mask, audio_embedding, image_embedding)

#             labels = labels[:len(logits)]
#             padding_len = logits.size(1) - labels.size(1)
#             if padding_len > 0:
#                 pad = torch.full((1, padding_len), -100, dtype=labels.dtype, device=labels.device)
#                 labels = torch.cat([labels, pad], dim=1)
#             else:
#                 labels = labels[:, :logits.size(1)]

#             preds = torch.argmax(logits, dim=-1)
#             mask = attention_mask.view(-1) == 1
#             all_preds.extend(preds.view(-1)[mask].cpu().numpy())
#             all_labels.extend(labels.view(-1)[mask].cpu().numpy())

#     filtered_preds = []
#     filtered_labels = []
#     for p, l in zip(all_preds, all_labels):
#         if l != -100:
#             filtered_preds.append(p)
#             filtered_labels.append(l)

#     print(classification_report(filtered_labels, filtered_preds, target_names=["None", "Emotion", "Cause"], digits=4))

from tqdm import tqdm
import torch
import os
from sklearn.metrics import classification_report

# ------------------ Train & Eval ------------------
def train(model, dataloader, optimizer, loss_fn, device, epoch=None, save_path=None):
    """Run one training epoch and optionally checkpoint the model.

    Args:
        model: Module called as ``model(input_ids, attention_mask, audio_embedding)``
            and returning logits of shape ``[B, T, C]``.
        dataloader: Iterable of batches (dicts with ``input_ids``,
            ``attention_mask``, ``audio_embedding`` and ``labels``); must support ``len()``.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``(logits[N, C], labels[N]) -> scalar loss``; label
            positions added here are filled with -100.
        device: Device the batch tensors are moved to.
        epoch: Epoch index used for the progress-bar title and checkpoint name.
        save_path: Directory for ``model_epoch_{epoch}.pt``; nothing is saved
            unless both ``save_path`` and ``epoch`` are given.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}" if epoch is not None else "Training")

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        audio_embedding = batch['audio_embedding'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask, audio_embedding)

        # Bring labels to the logits' sequence length: cut the surplus, then
        # fill any shortfall with -100.
        seq_len = logits.size(1)
        labels = labels[:, :seq_len]
        padding_len = seq_len - labels.size(1)
        if padding_len > 0:
            pad = labels.new_full((labels.size(0), padding_len), -100)
            labels = torch.cat([labels, pad], dim=1)

        # reshape (not view): the column slice above is non-contiguous when
        # labels were longer than the logits, and view(-1) would raise.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(dataloader)

    if save_path and epoch is not None:
        os.makedirs(save_path, exist_ok=True)
        ckpt_path = os.path.join(save_path, f"model_epoch_{epoch}.pt")
        torch.save(model.state_dict(), ckpt_path)
        print(f"✅ Saved model at: {ckpt_path}")

    return avg_loss

def evaluate(model, dataloader, device, loss_fn=None):
    """Compute the mean loss over ``dataloader`` and print a classification report.

    Args:
        model: Module called as ``model(input_ids, attention_mask, audio_embedding)``
            and returning logits of shape ``[B, T, C]``.
        dataloader: Iterable of batches (dicts with ``input_ids``,
            ``attention_mask``, ``audio_embedding`` and ``labels``); must support ``len()``.
        device: Device the batch tensors are moved to.
        loss_fn: Callable ``(logits[N, C], labels[N]) -> scalar loss``. When
            omitted, the notebook-level ``loss_fn`` variable is used, which
            keeps existing three-argument calls working.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: If no loss function is passed and none is defined globally.
    """
    if loss_fn is None:
        # Backward compatibility: earlier versions read the global directly.
        loss_fn = globals().get("loss_fn")
        if loss_fn is None:
            raise ValueError(
                "evaluate() needs a loss function: pass loss_fn=... or define a global `loss_fn`."
            )

    model.eval()
    all_preds, all_labels = [], []
    total_loss = 0.0  # To track the total loss
    progress_bar = tqdm(dataloader, desc="Evaluating")

    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio_embedding = batch['audio_embedding'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask, audio_embedding)

            # Bring labels to the logits' sequence length: cut the surplus,
            # then fill any shortfall with -100.
            seq_len = logits.size(1)
            labels = labels[:, :seq_len]
            padding_len = seq_len - labels.size(1)
            if padding_len > 0:
                pad = labels.new_full((labels.size(0), padding_len), -100)
                labels = torch.cat([labels, pad], dim=1)

            # reshape (not view): a column-sliced label tensor is non-contiguous.
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            total_loss += loss.item()

            # Score only attended positions that carry a real label.
            preds = torch.argmax(logits, dim=-1)
            keep = (attention_mask == 1) & (labels != -100)
            all_preds.extend(preds[keep].tolist())
            all_labels.extend(labels[keep].tolist())

    # Calculate the average validation loss
    avg_val_loss = total_loss / len(dataloader)

    if all_labels:
        # labels=[0, 1, 2] keeps the three target names valid even when a
        # class is missing from this split; zero_division=0 silences the
        # undefined-metric warning for such classes.
        print(classification_report(
            all_labels,
            all_preds,
            labels=[0, 1, 2],
            target_names=["None", "Emotion", "Cause"],
            digits=4,
            zero_division=0,
        ))
    else:
        print("No labelled tokens found; skipping classification report.")

    return avg_val_loss  # Return the average validation loss


In [None]:
# label_map = {'None': 0, 'Emotion': 1, 'Cause': 2}
# def parse_tokens(token_data):
#     texts = [item['token'] for item in token_data]
#     timestamps = [str(item['start']) for item in token_data]
#     labels = [label_map.get(item['label'] or 'None', 0) for item in token_data]
#     return texts, timestamps, labels

# from transformers import RobertaTokenizer
# import re
# import json

# # Load RoBERTa tokenizer
# tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# # Step 1: Parse the .txt file
# def parse_transcript_file(file_path):
#     with open(file_path, 'r', encoding='utf-8') as f:
#         lines = f.readlines()

#     full_text = ""
#     segments = []

#     # Extract full transcript text
#     text_started = False
#     for line in lines:
#         if line.strip().startswith("Transcript:"):
#             text_started = True
#             continue
#         if line.strip().startswith("Labeled Segments:"):
#             break
#         if text_started:
#             full_text += line.strip() + " "

#     # Extract labeled segments
#     segment_lines = lines[lines.index("Labeled Segments:\n") + 1:]
#     for line in segment_lines:
#         match = re.match(r'Time:\s*([\d.]+)-[\d.]+\s*sec\s*\|\s*Label:\s*(\w+|None)\s*\|\s*Text:\s*"(.*?)"', line)
#         if match:
#             start = float(match.group(1))
#             label = match.group(2) if match.group(2) != "None" else None
#             word = match.group(3)
#             segments.append((word, start, label))

#     return full_text.strip(), segments

# # Step 2: Tokenize and align
# def tokenize_and_align(text, segments):
#     tokens = []
#     seg_index = 0

#     for word, start, label in segments:
#         # Tokenize the word using RoBERTa
#         word_tokens = tokenizer.tokenize(word)
#         for tok in word_tokens:
#             tokens.append({
#                 'token': tok,
#                 'start': start,
#                 'label': label
#             })

#     return tokens

label_map = {'None': 0, 'Emotion': 1, 'Cause': 2}

def parse_tokens(token_data):
    """Split token records into parallel lists.

    Args:
        token_data: List of dicts with ``'token'``, ``'start'`` and ``'label'`` keys.

    Returns:
        ``(texts, timestamps, labels)``: the token strings, the start times
        converted with ``str()``, and the integer class ids. A falsy label is
        treated as ``'None'``; a label missing from ``label_map`` maps to 0.
    """
    texts, timestamps, labels = [], [], []
    for record in token_data:
        texts.append(record['token'])
        timestamps.append(str(record['start']))
        label_name = record['label'] or 'None'
        labels.append(label_map.get(label_name, 0))
    return texts, timestamps, labels

from transformers import RobertaTokenizer
import re

# Load RoBERTa tokenizer
# Module-level instance read by tokenize_and_align() below. The data-setup
# cell later rebinds `tokenizer` to another instance loaded from the same
# "roberta-base" checkpoint.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Step 1: Parse the .txt file
def parse_transcript_file(file_path):
    """Parse a transcript file into its full text and its labeled segments.

    The file is read in two sections: the lines between a line starting with
    ``Transcript:`` and a line starting with ``Labeled Segments:`` form the
    text; every following line shaped like
    ``Time: <start>-<end> sec | Label: <label> | Text: "<text>"`` is a segment.

    Args:
        file_path: Path to a UTF-8 encoded transcript file.

    Returns:
        ``(full_text, segments)`` where ``full_text`` is the transcript lines
        joined by single spaces and ``segments`` is a list of
        ``(text, start_seconds, label)`` tuples; ``label`` is ``None`` when
        the file says ``None``.

    Raises:
        ValueError: If the file has no ``Labeled Segments:`` header.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    segment_re = re.compile(
        r'Time:\s*([\d.]+)-[\d.]+\s*sec\s*\|\s*Label:\s*(\w+|None)\s*\|\s*Text:\s*"(.*?)"'
    )

    transcript_parts = []
    segments = []
    section = None  # None -> "transcript" -> "segments"

    for line in lines:
        stripped = line.strip()

        if section == "segments":
            match = segment_re.match(line)
            if match:
                start = float(match.group(1))
                label = match.group(2) if match.group(2) != "None" else None
                segments.append((match.group(3), start, label))
            continue

        # Both headers are recognised after stripping, so trailing whitespace
        # or a missing final newline on the header line is tolerated.
        if stripped.startswith("Transcript:"):
            section = "transcript"
        elif stripped.startswith("Labeled Segments:"):
            section = "segments"
        elif section == "transcript":
            transcript_parts.append(stripped)

    if section != "segments":
        raise ValueError(f"No 'Labeled Segments:' section found in {file_path}")

    return " ".join(transcript_parts).strip(), segments

# Step 2: Tokenize and align
def tokenize_and_align(text, segments):
    """Expand labeled segments into one record per tokenizer piece.

    Args:
        text: Full transcript text; accepted for interface compatibility but
            not read by this function.
        segments: Iterable of ``(word, start, label)`` tuples.

    Returns:
        List of ``{'token', 'start', 'label'}`` dicts. Every piece produced by
        the module-level ``tokenizer`` for a segment inherits that segment's
        start time and label.
    """
    return [
        {'token': piece, 'start': start, 'label': label}
        for word, start, label in segments
        for piece in tokenizer.tokenize(word)
    ]

In [None]:
from torch.nn.utils.rnn import pad_sequence

# Custom collate function
def collate_fn(batch, pad_token_id=1, label_pad_id=-100):
    """Stack dataset items into one batch, right-padding along the sequence axis.

    Args:
        batch: List of dicts with ``input_ids``, ``attention_mask``,
            ``audio_embedding`` and ``labels`` tensors.
        pad_token_id: Fill value for ``input_ids``; 1 is the ``<pad>`` id of
            the RoBERTa vocabulary (0 is ``<s>``).
        label_pad_id: Fill value for ``labels``. -100 matches the
            ``ignore_index`` of the loss used in this notebook, so padded
            positions are not trained on; padding with 0 would turn them into
            real class-0 ("None") targets.

    Returns:
        dict with the same four keys, each a batch-first padded tensor.
    """
    def _pad(key, fill):
        return pad_sequence([item[key] for item in batch], batch_first=True, padding_value=fill)

    return {
        'input_ids': _pad('input_ids', pad_token_id),
        'attention_mask': _pad('attention_mask', 0),
        'audio_embedding': _pad('audio_embedding', 0),
        'labels': _pad('labels', label_pad_id),
    }

In [None]:
# import os, random
# from sklearn.model_selection import train_test_split

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# video_folder = "/content/drive/MyDrive/nlpProject/Videos"
# audio_folder = "/content/drive/MyDrive/nlpProject/Audios"
# txt_folder = "/content/drive/MyDrive/nlpProject/Done_text"

# triplets = []
# for fname in os.listdir(video_folder):
#     base = os.path.splitext(fname)[0]
#     v_path = os.path.join(video_folder, fname)
#     a_path = os.path.join(audio_folder, base + ".mp3")
#     t_path = os.path.join(txt_folder, base + ".txt")
#     if os.path.exists(a_path) and os.path.exists(t_path):
#         triplets.append((t_path, a_path, v_path))

# data = []
# for txt_path, audio_path, video_path in triplets:
#     text, segments = parse_transcript_file(txt_path)
#     token_data = tokenize_and_align(text, segments)
#     texts, timestamps, labels = parse_tokens(token_data)
#     data.append((texts, timestamps, labels, audio_path, video_path))

# train_data, val_test_data = train_test_split(data, test_size=0.3, random_state=42)
# val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=42)

# def unpack_split(split_data):
#     texts, timestamps, labels, audio_paths, video_paths = zip(*split_data)
#     return list(texts), list(timestamps), list(labels), list(audio_paths), list(video_paths)

# train_texts, train_timestamps, train_labels, train_audios, train_videos = unpack_split(train_data)
# val_texts, val_timestamps, val_labels, val_audios, val_videos = unpack_split(val_data)
# test_texts, test_timestamps, test_labels, test_audios, test_videos = unpack_split(test_data)

# # train_dataset = TokenAudioVisualDataset(train_texts, train_timestamps, train_labels, train_audios, train_videos, tokenizer, clip_processor, clip_model, max_len=16, device=device)
# # val_dataset = TokenAudioVisualDataset(val_texts, val_timestamps, val_labels, val_audios, val_videos, tokenizer, clip_processor, clip_model, max_len=16, device=device)
# # test_dataset = TokenAudioVisualDataset(test_texts, test_timestamps, test_labels, test_audios, test_videos, tokenizer, clip_processor, clip_model, max_len=16, device=device)

# from torch.utils.data import DataLoader

# # Create the train, val, and test datasets
# train_dataset = TokenAudioVisualDataset(
#     train_texts, train_timestamps, train_audios, train_videos, train_labels,
#     tokenizer, clip_processor, clip_model, max_len=512, device=device
# )

# val_dataset = TokenAudioVisualDataset(
#     val_texts, val_timestamps, val_audios, val_videos, val_labels,
#     tokenizer, clip_processor, clip_model, max_len=512, device=device
# )

# test_dataset = TokenAudioVisualDataset(
#     test_texts, test_timestamps, test_audios, test_videos, test_labels,
#     tokenizer, clip_processor, clip_model, max_len=512, device=device
# )

# # Wrap datasets in dataloaders
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


import os, random
import torch
from sklearn.model_selection import train_test_split
from transformers import RobertaTokenizer

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Data directories
audio_folder = "/content/drive/MyDrive/nlpProject/Audios"
txt_folder = "/content/drive/MyDrive/nlpProject/Done_text"

# Tunables shared by the three splits
MAX_LEN = 512
BATCH_SIZE = 32

# Build (transcript, audio) pairs. The list keeps its historical name
# `triplets` even though the video path is no longer part of it.
# sorted(): os.listdir returns entries in arbitrary order, which would make
# the seeded split below differ from run to run.
triplets = []
for fname in sorted(os.listdir(txt_folder)):
    base, ext = os.path.splitext(fname)
    t_path = os.path.join(txt_folder, fname)
    # Skip sub-directories and anything that is not a .txt transcript.
    if ext.lower() != ".txt" or not os.path.isfile(t_path):
        continue
    a_path = os.path.join(audio_folder, base + ".mp3")
    if os.path.exists(a_path):
        triplets.append((t_path, a_path))

# Process data: one (tokens, timestamps, labels, audio_path) tuple per file
data = []
for txt_path, audio_path in triplets:
    text, segments = parse_transcript_file(txt_path)
    token_data = tokenize_and_align(text, segments)
    texts, timestamps, labels = parse_tokens(token_data)
    data.append((texts, timestamps, labels, audio_path))

# Train/val/test split (70 / 15 / 15)
train_data, val_test_data = train_test_split(data, test_size=0.3, random_state=42)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=42)

# Unpack utility
def unpack_split(split_data):
    """Transpose a list of sample tuples into four parallel lists."""
    texts, timestamps, labels, audio_paths = zip(*split_data)
    return list(texts), list(timestamps), list(labels), list(audio_paths)

# Unpack each split
train_texts, train_timestamps, train_labels, train_audios = unpack_split(train_data)
val_texts, val_timestamps, val_labels, val_audios = unpack_split(val_data)
test_texts, test_timestamps, test_labels, test_audios = unpack_split(test_data)


def _make_dataset(split_texts, split_audios, split_labels):
    """Build a text+audio dataset for one split with the shared settings."""
    return TokenAudioDataset(
        texts=split_texts,
        audio_paths=split_audios,
        labels=split_labels,
        tokenizer=tokenizer,
        max_len=MAX_LEN,
        device=device
    )


# Datasets (text + audio only)
train_dataset = _make_dataset(train_texts, train_audios, train_labels)
val_dataset = _make_dataset(val_texts, val_audios, val_labels)
test_dataset = _make_dataset(test_texts, test_audios, test_labels)


# DataLoaders
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)


In [None]:
# import os
# import glob
# import re
# import torch

# model = RobertaAudioVisualClassifier().to(device)

# # Find all checkpoint files like: model_epoch_*.pt
# checkpoint_files = sorted(
#     glob.glob(os.path.join(CHECKPOINT_DIR, "model_epoch_*.pt"))
# )

# latest_epoch = 0  # Default if no checkpoints found

# if checkpoint_files:
#     # Extract epoch numbers and find the latest one
#     epochs = [
#         int(re.search(r"model_epoch_(\d+)\.pt", f).group(1))
#         for f in checkpoint_files
#         if re.search(r"model_epoch_(\d+)\.pt", f)
#     ]
#     latest_epoch = max(epochs)
#     latest_ckpt_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{latest_epoch}.pt")

#     model.load_state_dict(torch.load(latest_ckpt_path, map_location=device))
#     print(f"✅ Loaded model from {latest_ckpt_path} (epoch {latest_epoch})")
# else:
#     print("⚠️ No checkpoint found, training from scratch.")


# optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# loss_fn = nn.CrossEntropyLoss(ignore_index=-100, weight=torch.tensor([0.05, 0.475, 0.475], device=device))

# best_val_loss, patience, wait = float('inf'), 3, 0

import os
import glob
import re
import torch
import torch.nn as nn

# Build the text+audio model on the target device
model = RobertaAudioClassifier().to(device)  # Replace with your new class name

# Gather every saved per-epoch checkpoint (model_epoch_*.pt)
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "model_epoch_*.pt")))

latest_epoch = 0  # Stays 0 when no checkpoint exists

if not checkpoint_files:
    print("⚠️ No checkpoint found, training from scratch.")
else:
    # Pull the epoch number out of each file name and keep the highest one
    epochs = []
    for ckpt_file in checkpoint_files:
        found = re.search(r"model_epoch_(\d+)\.pt", ckpt_file)
        if found:
            epochs.append(int(found.group(1)))
    latest_epoch = max(epochs)
    latest_ckpt_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{latest_epoch}.pt")

    state_dict = torch.load(latest_ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"✅ Loaded model from {latest_ckpt_path} (epoch {latest_epoch})")

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Weighted cross-entropy; weights are indexed by class id (same order as label_map),
# and positions labelled -100 are ignored
class_weights = torch.tensor([0.05, 0.475, 0.475], device=device)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)

# Early-stopping state
best_val_loss, patience, wait = float('inf'), 3, 0

In [None]:
NUM_EPOCHS = 20                    # total epoch budget, including epochs done before a resume
BEST_MODEL_PATH = "best_model.pt"  # weights of the best-validation epoch of this run

# Continue after the last checkpointed epoch instead of restarting at 0, which
# would retrain from a later model and overwrite the earlier checkpoints.
# `checkpoint_files` / `latest_epoch` come from the checkpoint-loading cell.
start_epoch = latest_epoch + 1 if checkpoint_files else 0
saved_best = False  # True once this run has written BEST_MODEL_PATH

for epoch in range(start_epoch, NUM_EPOCHS):
    train_loss = train(model, train_loader, optimizer, loss_fn, device, epoch=epoch, save_path=CHECKPOINT_DIR)
    val_loss = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        saved_best = True
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping.")
            break

# Score the best-validation weights, not whatever the last epoch left behind.
# Only a file written by this run is loaded, so a stale best_model.pt is ignored.
if saved_best and os.path.exists(BEST_MODEL_PATH):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))

print("Final evaluation on test set:")
evaluate(model, test_loader, device)