In [None]:
!unzip text

In [None]:
# ================================ FINAL WORKING VERSION – 2199 ALIGNED SEGMENTS ================================
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Pick the accelerator once for the whole notebook: CUDA when PyTorch can see a GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# ================================ DATA LOADER (FIXED WITH INTERVAL ALIGNMENT) ================================
def _load_text_features(text_path):
    """Read per-video word vectors and word timestamps from an HDF5 ``.csd`` file.

    Returns a dict ``video_id -> {'features': array, 'intervals': array}`` with
    both arrays cast to float32. Videos missing either dataset are skipped
    with a warning.
    """
    print(f"Loading text from: {text_path}")
    text_data = {}
    with h5py.File(text_path, 'r') as f:
        root = f['glove_vectors']['data']
        for video_id in root.keys():
            group = root[video_id]
            if 'features' in group and 'intervals' in group:
                text_data[video_id] = {
                    'features': group['features'][()].astype(np.float32),
                    'intervals': group['intervals'][()].astype(np.float32),
                }
            else:
                print(f"Warning: Video {video_id} missing features or intervals. Skipping.")
    print(f"→ Loaded text for {len(text_data)} videos")
    return text_data


def _load_label_scores(label_path):
    """Read per-video segment scores and segment intervals from an HDF5 ``.csd`` file.

    Returns a dict ``video_id -> {'labels': (S,) float32, 'intervals': float32}``.
    Videos missing either dataset are skipped with a warning.
    """
    print(f"Loading labels from: {label_path}")
    label_data = {}
    with h5py.File(label_path, 'r') as f:
        root = f['Opinion Segment Labels']['data']
        for video_id in root.keys():
            group = root[video_id]
            if 'features' not in group or 'intervals' not in group:
                print(f"Warning: Video {video_id} missing label features or intervals. Skipping.")
                continue
            raw_labels = group['features'][()]
            intervals = group['intervals'][()].astype(np.float32)

            if raw_labels.ndim > 1 and raw_labels.shape[-1] == 7:
                # Treated as a probability distribution over the integer scores
                # -3..3; collapse it to its expected value.
                scores = np.dot(raw_labels, np.arange(-3, 4)).astype(np.float32)
            else:
                scores = raw_labels.flatten().astype(np.float32)

            label_data[video_id] = {'labels': scores, 'intervals': intervals}

    print(f"→ Loaded labels for {len(label_data)} videos")
    total_labels = sum(len(v['labels']) for v in label_data.values())
    print(f"→ Total label segments: {total_labels}")
    return label_data


def _align_segments(text_data, label_data):
    """For every labelled segment, collect the word vectors inside its time span.

    Returns ``(sequences, labels)`` as parallel Python lists.
    """
    sequences = []
    labels = []

    for video_id in sorted(text_data):
        if video_id not in label_data:
            print(f"SKIP: {video_id} has text but no labels")
            continue

        text_feats = text_data[video_id]['features']
        text_ints = text_data[video_id]['intervals']
        seg_labels = label_data[video_id]['labels']
        seg_ints = label_data[video_id]['intervals']

        if len(seg_labels) != len(seg_ints):
            print(f"Error: Mismatch in {video_id} label count vs intervals. Skipping.")
            continue

        # Width of one word vector; used for the placeholder row of empty segments.
        feat_dim = text_feats.shape[-1]
        word_starts = text_ints[:, 0]
        word_ends = text_ints[:, 1]

        for (seg_start, seg_end), score in zip(seg_ints, seg_labels):
            # Words with word_start >= seg_start and word_end <= seg_end.
            # NOTE(review): binary search assumes both word start and end times
            # are sorted ascending within a video — confirm for the dataset.
            start_idx = np.searchsorted(word_starts, seg_start, side='left')
            end_idx = np.searchsorted(word_ends, seg_end, side='right')

            seg_feat = text_feats[start_idx:end_idx]
            if len(seg_feat) == 0:
                # Keep the segment (and its label) with a single all-zero row.
                seg_feat = np.zeros((1, feat_dim), dtype=np.float32)
            else:
                # Drop trailing all-zero rows (interior zero rows are kept).
                nonzero = np.flatnonzero(np.any(seg_feat != 0, axis=1))
                if len(nonzero) > 0:
                    seg_feat = seg_feat[:nonzero[-1] + 1]

            sequences.append(seg_feat)
            labels.append(score)

    return sequences, labels


def load_mosi_text_only(
    text_path='text/CMU_MOSI_TimestampedWordVectors_1.1.csd',
    label_path='text/CMU_MOSI_Opinion_Labels.csd'
):
    """Load CMU-MOSI word vectors and opinion labels and align them per segment.

    Word features/intervals are read from the ``glove_vectors/data`` group of
    ``text_path``. Segment labels/intervals are read from the
    ``Opinion Segment Labels/data`` group of ``label_path``. Each labelled
    segment becomes one sequence: the words whose time span lies inside the
    segment interval.

    Args:
        text_path: Path to the word-vector ``.csd`` (HDF5) file.
        label_path: Path to the opinion-label ``.csd`` (HDF5) file.

    Returns:
        ``(sequences, labels)``. ``sequences`` is a list of float32 arrays of
        shape ``(num_words, feature_dim)``, with a single zero row when no word
        falls inside the segment. ``labels`` is a float32 array holding one
        score per sequence.

    Raises:
        ValueError: If no segment could be aligned. Previously ``min()`` raised
            this on an empty list.
    """
    text_data = _load_text_features(text_path)
    label_data = _load_label_scores(label_path)
    sequences, labels = _align_segments(text_data, label_data)

    if not sequences:
        raise ValueError("No segments could be aligned between the text and label files.")

    print(f"\nSUCCESS! Aligned {len(sequences)} segments")
    print(f"Label range: {min(labels):.2f} to {max(labels):.2f}")
    print(f"Average segment length: {np.mean([len(s) for s in sequences]):.1f} words")

    return sequences, np.array(labels, dtype=np.float32)


# ================================ DATASET & COLLATE ================================
class MOSIDataset(Dataset):
    """Fixed-length dataset of word-vector sequences with scalar regression targets.

    Each item is truncated to ``max_len`` rows, or right-padded with zero rows
    up to ``max_len``. The feature width is taken from the sequence itself
    rather than being hard-coded.
    """

    def __init__(self, seqs, labs, max_len=50):
        # seqs: indexable collection of 2-D arrays (num_words, feature_dim).
        # labs: indexable collection of scalar targets, parallel to seqs.
        self.seqs = seqs
        self.labs = labs
        self.max_len = max_len

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        """Return ``(features, label)``: a (max_len, feature_dim) float32 tensor and a scalar tensor."""
        # Cast so the tensor always matches the model's float32 weights.
        x = np.asarray(self.seqs[i], dtype=np.float32)[:self.max_len]
        missing = self.max_len - len(x)
        if missing > 0:
            pad = np.zeros((missing, x.shape[1]), dtype=np.float32)
            x = np.concatenate([x, pad], axis=0)
        return torch.from_numpy(x), torch.tensor(self.labs[i], dtype=torch.float32)


def collate(batch):
    """Stack dataset items into a batch and recover each sequence's true length.

    Items are ``(features, label)`` pairs whose features have already been
    right-padded with zero rows to a common length. The old code took
    ``len()`` of those padded tensors, so every length equalled the padded
    width. Packing was then a no-op, and attention also covered padding.

    The true length is the index of the last non-zero row plus one, clamped to
    at least 1 because packing rejects zero lengths. The batch is trimmed to
    its longest true length, so ``seqs.size(1) == lengths.max()``.

    Returns:
        ``(seqs, labs, lengths)`` with shapes ``(B, T, D)``, ``(B, 1)`` and ``(B,)``.
        ``lengths`` is int64.
    """
    seqs, labs = zip(*batch)
    seqs = torch.stack(seqs)
    labs = torch.stack(labs).unsqueeze(1)

    # 1-based position of the last non-zero row per sequence (0 if all rows are zero).
    has_data = seqs.ne(0).any(dim=-1)
    positions = torch.arange(1, seqs.size(1) + 1)
    lengths = (has_data * positions).max(dim=1).values.clamp(min=1)

    # Drop columns that are padding for every item in the batch.
    seqs = seqs[:, : int(lengths.max())].contiguous()
    return seqs, labs, lengths


# ================================ MODEL (BiLSTM + Attention) ================================
class Model(nn.Module):
    """Two-layer bidirectional LSTM with additive attention pooling and an MLP regression head.

    Args:
        input_dim: Width of each input timestep vector. The default is 300,
            matching the original hard-coded value.
        hidden_dim: LSTM hidden size per direction. The pooled context is
            ``2 * hidden_dim`` wide.

    Defaults reproduce the original architecture, so existing ``state_dict``
    checkpoints still load.
    """

    def __init__(self, input_dim=300, hidden_dim=256):
        super().__init__()
        ctx_dim = 2 * hidden_dim  # forward + backward directions concatenated
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, bidirectional=True,
                            batch_first=True, dropout=0.5)
        self.attn = nn.Linear(ctx_dim, 1)
        self.norm = nn.LayerNorm(ctx_dim)
        self.head = nn.Sequential(
            nn.Dropout(0.5), nn.Linear(ctx_dim, 128), nn.GELU(),
            nn.Dropout(0.3), nn.Linear(128, 1)
        )

    def forward(self, x, lens):
        """Predict one scalar per sequence.

        Args:
            x: Padded batch of shape ``(batch, seq_len, input_dim)``.
            lens: Valid timesteps per sequence, shape ``(batch,)``, each >= 1.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        seq_len = x.size(1)
        packed = nn.utils.rnn.pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.lstm(packed)
        # total_length keeps the time axis at seq_len. Without it the output is
        # only max(lens) long, and the mask below fails to broadcast whenever
        # the batch is shorter than its padded width.
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True, total_length=seq_len)

        scores = self.attn(out).squeeze(-1)
        pad_mask = torch.arange(seq_len, device=x.device)[None, :] >= lens.to(x.device)[:, None]
        # A large negative fill makes padded positions get ~0 softmax weight.
        scores = scores.masked_fill(pad_mask, -1e9)
        weights = torch.softmax(scores, dim=1)

        ctx = (out * weights.unsqueeze(-1)).sum(1)
        ctx = self.norm(ctx)
        return self.head(ctx)


# ================================ TRAIN LOOP ================================
def _run_epoch(model, loader, loss_fn, opt=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    When ``opt`` is given the model is trained, with gradient-norm clipping at
    1.0. Otherwise it is evaluated with gradients disabled. The loss is
    weighted by batch size, so a smaller final batch is not over-counted.
    """
    training = opt is not None
    model.train(training)
    total, count = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, y, l in loader:
            x, y, l = x.to(device), y.to(device), l.to(device)
            p = model(x, l)
            loss = loss_fn(p, y)
            if training:
                opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            total += loss.item() * x.size(0)
            count += x.size(0)
    return total / max(count, 1)


def train():
    """Train the BiLSTM-attention regressor on CMU-MOSI text and keep the best checkpoint.

    Uses an 80/20 split, AdamW, ReduceLROnPlateau on validation MSE, and early
    stopping after 20 epochs without improvement. Runs at most 200 epochs.
    Whenever validation MSE improves by more than 1e-4, the weights are saved
    to ``best_mosi_text_final.pth``.
    """
    print("Loading CMU-MOSI text data...")
    seqs, labs = load_mosi_text_only()

    # Stratify on sentiment sign: 1 = non-negative (positive or neutral), 0 = negative.
    # NOTE(review): the split is per segment, so segments from the same video can
    # appear in both train and validation. A video-level split would avoid this
    # leakage but needs video ids from the loader.
    bin_lab = (labs >= 0).astype(int)
    tr_x, val_x, tr_y, val_y = train_test_split(
        seqs, labs, test_size=0.2, random_state=42, stratify=bin_lab
    )

    tr_ds = MOSIDataset(tr_x, tr_y, max_len=50)
    val_ds = MOSIDataset(val_x, val_y, max_len=50)

    # Pinned memory only helps host-to-GPU copies; skip it on CPU.
    pin = device.type == "cuda"
    tr_dl = DataLoader(tr_ds, batch_size=32, shuffle=True, collate_fn=collate, num_workers=2, pin_memory=pin)
    val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, collate_fn=collate, num_workers=2, pin_memory=pin)

    model = Model().to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5, patience=7)
    loss_fn = nn.MSELoss()

    best_val_loss = float('inf')
    patience_counter = 0
    max_patience = 20

    print("\n" + "="*60)
    print("TRAINING STARTED")
    print("="*60)

    for epoch in range(1, 201):
        tr_loss = _run_epoch(model, tr_dl, loss_fn, opt)
        val_loss = _run_epoch(model, val_dl, loss_fn)
        sch.step(val_loss)

        print(f"Epoch {epoch:03d} | Train MSE: {tr_loss:.4f} | Val MSE: {val_loss:.4f}")

        if val_loss < best_val_loss - 1e-4:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_mosi_text_final.pth")
            print(f"   >>> BEST MODEL! Val MSE = {best_val_loss:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                print(f"Early stopping triggered after {epoch} epochs.")
                break

    print(f"\nFINISHED! Best Validation MSE = {best_val_loss:.4f}")
    print("Model saved as: best_mosi_text_final.pth")


# ================================ RUN ================================
# Entry point: start training when this file or cell is executed directly.
# In a notebook cell, __name__ is "__main__", so training starts immediately.
if __name__ == "__main__":
    train()