# In [None]:
import glob
import os
import warnings

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageFile
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import EfficientNet_B5_Weights
from tqdm import tqdm

# -------------------------
# 1. Load dataset
# -------------------------
# The annotation table is parsed as tab-separated text; pass sep="," instead
# if the file turns out to be comma-separated.
df = pd.read_csv(
    "mvsa_image_soft_labels.txt",
    sep="\t",
)
print("Original dataset shape:", df.shape)

# -------------------------
# 2. Robust Dataset (never crashes)
# -------------------------
# Ask Pillow to load truncated image files rather than raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

class SafeImageDataset(Dataset):
    """Image classification dataset that degrades gracefully on bad samples.

    Every row of ``dataframe`` needs an ``ID`` column holding the numeric stem
    of an image file inside ``image_dir`` and, in its second column, a label
    cell of the form ``"<first>,<image_label>"``. The part after the last
    comma is looked up in ``label_map``; a cell without a comma, or with an
    unknown label, counts as neutral.

    A sample whose image is missing, unreadable, or fails in ``transform`` is
    replaced by an all-zero tensor of shape ``FALLBACK_SHAPE`` labelled
    neutral, and a ``RuntimeWarning`` naming the ID is emitted so that the
    substitution does not go unnoticed.

    Args:
        dataframe: Table of samples. It is copied, never modified in place.
        image_dir: Directory searched for ``<ID>.png``, ``<ID>.jpg`` and
            ``<ID>.jpeg``, in that order.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Raises:
        Whatever ``astype(int)`` raises when an ``ID`` value cannot be
        converted to an integer (at construction time).
    """

    # File extensions probed, in order, when resolving an ID to a file.
    IMAGE_EXTENSIONS = ("png", "jpg", "jpeg")
    # Shape of the placeholder returned for unusable samples. Presumably
    # chosen to match the output of the transform used with this dataset so
    # that batches still collate -- confirm when changing the transform.
    FALLBACK_SHAPE = (3, 456, 456)
    # Class index used whenever a label cannot be determined (neutral).
    FALLBACK_LABEL = 1

    def __init__(self, dataframe, image_dir, transform=None):
        self.transform = transform
        self.image_dir = image_dir
        self.label_map = {'negative': 0, 'neutral': 1, 'positive': 2}
        # Normalise IDs on a private copy: the caller's frame (often a slice
        # produced by a train/test split) must not be modified as a side
        # effect of building a dataset.
        frame = dataframe.copy()
        frame['ID'] = frame['ID'].astype(int).astype(str)
        self.data = frame.reset_index(drop=True)

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Deliberately outside the try block: an out-of-range index has to
        # raise IndexError, otherwise plain iteration over the dataset would
        # never terminate.
        row = self.data.iloc[idx]
        img_id = str(row["ID"])

        try:
            image = self._load_image(img_id)
            label_id = self._parse_label(row.iloc[1])
        except Exception as exc:
            # Best effort by design: one unusable sample must not abort an
            # epoch. The warning keeps the substitution visible.
            warnings.warn(
                f"SafeImageDataset: using placeholder for ID {img_id!r} "
                f"({type(exc).__name__}: {exc})",
                RuntimeWarning,
            )
            image = torch.zeros(self.FALLBACK_SHAPE)
            label_id = self.FALLBACK_LABEL

        return image, torch.tensor(label_id, dtype=torch.long)

    def _load_image(self, img_id):
        """Load the image for ``img_id`` as RGB and apply ``self.transform``.

        Raises:
            FileNotFoundError: If no file with a known extension exists.
        """
        for ext in self.IMAGE_EXTENSIONS:
            img_path = os.path.join(self.image_dir, f"{img_id}.{ext}")
            if os.path.exists(img_path):
                break
        else:
            raise FileNotFoundError(
                f"no image file for ID {img_id!r} in {self.image_dir!r}"
            )

        # The context manager closes the file as soon as the pixel data has
        # been converted, instead of leaving that to garbage collection.
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image

    def _parse_label(self, label_cell):
        """Map a raw label cell to its class index, defaulting to neutral."""
        label_str = str(label_cell)
        if ',' not in label_str:
            return self.FALLBACK_LABEL
        # Only the text after the last comma is the image label. Splitting
        # once from the right tolerates additional commas, and stripping and
        # lower-casing tolerates cosmetic differences in the file.
        image_label = label_str.rsplit(',', 1)[1].strip().lower()
        return self.label_map.get(image_label, self.FALLBACK_LABEL)

# -------------------------
# 3. Train/Validation/Test split
# -------------------------
# IDs are stored as integer-valued strings.
df['ID'] = df['ID'].astype(int).astype(str)

# First cut: 80 % for training, 20 % held out. Both cuts are stratified on
# the second column of the frame being split and share the same seed.
df_train, df_temp = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df.iloc[:, 1],
)
# Second cut: the held-out part is halved into validation and test.
df_val, df_test = train_test_split(
    df_temp,
    test_size=0.5,
    random_state=42,
    stratify=df_temp.iloc[:, 1],
)

print(f"Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}")

# -------------------------
# 4. Transforms
# -------------------------
# Use the preprocessing preset attached to the pretrained weights
# (resizing and normalization) for every split.
weights = EfficientNet_B5_Weights.DEFAULT
transform = weights.transforms()

# -------------------------
# 5. Datasets and DataLoaders
# -------------------------
# One dataset per split, all reading images from the "data" directory.
train_dataset, val_dataset, test_dataset = (
    SafeImageDataset(frame, "data", transform=transform)
    for frame in (df_train, df_val, df_test)
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader, test_loader = (
    DataLoader(split, shuffle=False, batch_size=16)
    for split in (val_dataset, test_dataset)
)

# -------------------------
# 6. Model
# -------------------------
num_classes = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the pretrained network and replace its final linear layer with
# a freshly initialised one that outputs `num_classes` logits.
model = models.efficientnet_b5(weights=weights)
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)
model = model.to(device)

# -------------------------
# 7. Loss and optimizer
# -------------------------
criterion = nn.CrossEntropyLoss()
# All parameters are optimised, i.e. the whole network is fine-tuned.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# -------------------------
# 8. Training loop (fully crash-proof)
# -------------------------
# Each epoch runs one optimisation pass over train_loader followed by a
# gradient-free pass over val_loader. A batch that raises is skipped instead
# of aborting the run, but every skip is reported and counted so that a
# systematic failure (for example, every batch failing with the same error)
# cannot be mistaken for a successful training run.
epochs = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    skipped_train = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", unit="batch")

    for batch in loop:
        try:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            # criterion returns the batch mean, so weight it by the batch
            # size to obtain a correct per-sample average at the end.
            running_loss += loss.item() * images.size(0)
            running_corrects += (preds == labels).sum().item()
            total_samples += images.size(0)

            loop.set_postfix({
                'loss': f"{running_loss/total_samples:.4f}",
                'acc': f"{running_corrects/total_samples:.4f}"
            })

        except Exception as exc:
            # Best effort: skip the batch, but say so. tqdm.write prints
            # without corrupting the active progress bar.
            skipped_train += 1
            tqdm.write(
                f"Epoch {epoch+1}/{epochs}: skipped training batch "
                f"({type(exc).__name__}: {exc})"
            )
            continue

    # max(..., 1) guards the division when no batch completed.
    avg_train_loss = running_loss / max(total_samples, 1)
    train_acc = running_corrects / max(total_samples, 1)

    # --- Validation ---
    model.eval()
    val_loss = 0.0
    val_corrects = 0
    val_total = 0
    skipped_val = 0

    with torch.no_grad():
        for batch in val_loader:
            try:
                images, labels = batch
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                val_loss += loss.item() * images.size(0)
                val_corrects += (preds == labels).sum().item()
                val_total += images.size(0)

            except Exception as exc:
                skipped_val += 1
                tqdm.write(
                    f"Epoch {epoch+1}/{epochs}: skipped validation batch "
                    f"({type(exc).__name__}: {exc})"
                )
                continue

    avg_val_loss = val_loss / max(val_total, 1)
    val_acc = val_corrects / max(val_total, 1)
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if skipped_train or skipped_val:
        print(f"  Skipped batches - train: {skipped_train}, val: {skipped_val}")
    if total_samples == 0:
        # Zero loss/accuracy here means "nothing ran", not "perfect fit".
        print("  WARNING: no training batch completed in this epoch; "
              "the training metrics above are placeholders.")

print("Training finished!")

# -------------------------
# 9. Test evaluation (fully crash-proof)
# -------------------------
# Single gradient-free pass over test_loader. A failing batch is skipped so
# the evaluation always completes, but it is reported and counted: the
# accuracy below only covers the batches that actually ran.
model.eval()
test_correct = 0
test_total = 0
test_skipped = 0
with torch.no_grad():
    for batch in test_loader:
        try:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)
        except Exception as exc:
            test_skipped += 1
            print(f"Skipped test batch ({type(exc).__name__}: {exc})")
            continue

# max(..., 1) guards the division when no batch completed.
test_acc = test_correct / max(test_total, 1)
print(f"Test Accuracy: {test_acc:.4f}")
if test_skipped:
    print(f"Test batches skipped: {test_skipped} "
          f"(accuracy computed over {test_total} samples)")
