# In [None]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import DistilBertModel, DistilBertTokenizer
from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


# =========================
# Model
# =========================
class MultiModalModel(nn.Module):
    """Image + text classifier built on two frozen pretrained encoders.

    A ResNet-18 (with its final ``fc`` layer replaced by an identity) encodes
    the image and a DistilBERT model encodes the text. Both feature vectors
    are concatenated and passed through a small MLP head, which is the only
    part whose parameters require gradients after construction.

    Args:
        num_classes: Number of output logits produced by the classifier head.
    """

    def __init__(self, num_classes):
        super().__init__()

        # -------- Image Branch --------
        self.image_model = models.resnet18(weights="IMAGENET1K_V1")
        # Drop the ImageNet classification layer so the branch emits features.
        self.image_model.fc = nn.Identity()

        # Freeze ResNet
        for param in self.image_model.parameters():
            param.requires_grad = False

        # -------- Text Branch --------
        self.text_model = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # Freeze DistilBERT
        for param in self.text_model.parameters():
            param.requires_grad = False

        # -------- Fusion Classifier --------
        # Input width = 512 image features + 768 text features.
        self.classifier = nn.Sequential(
            nn.Linear(512 + 768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _is_frozen(module):
        """Return True when no parameter of *module* requires gradients."""
        return not any(p.requires_grad for p in module.parameters())

    def train(self, mode=True):
        """Set the training mode, keeping fully-frozen encoders in eval mode.

        ``requires_grad = False`` only stops gradient updates. In training
        mode BatchNorm layers would still update their running statistics and
        dropout would still perturb the features, so a frozen encoder would
        not actually behave as a fixed feature extractor. An encoder whose
        parameters have been unfrozen by the caller follows *mode* as usual.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            The module itself, matching ``nn.Module.train``.
        """
        super().train(mode)
        for encoder in (self.image_model, self.text_model):
            if self._is_frozen(encoder):
                encoder.eval()
        return self

    def forward(self, image, input_ids, attention_mask):
        """Compute class logits for a batch of image/text pairs.

        Args:
            image: Batch of images accepted by the ResNet branch.
            input_ids: Token ids for the text branch.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Tensor of logits produced by the fusion classifier.
        """
        img_features = self.image_model(image)

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use the hidden state at sequence position 0 as the sentence vector
        # (presumably the [CLS] token when inputs come from the DistilBERT
        # tokenizer).
        text_features = text_outputs.last_hidden_state[:, 0, :]

        combined = torch.cat((img_features, text_features), dim=1)
        return self.classifier(combined)


# =========================
# Dataset
# =========================
class GarbageDataset(Dataset):
    """Image/text dataset read from a ``root_dir/<class_name>/<file>`` layout.

    Every sub-directory of ``root_dir`` is one class; labels are the index of
    the class name in the sorted class list. The text for a sample is derived
    from its file name (see ``extract_text_from_filename``).

    Args:
        root_dir: Directory containing one sub-directory per class.
        tokenizer: Callable tokenizer invoked with ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments.
        transform: Optional callable applied to the loaded RGB PIL image.
        max_length: Length every tokenized text is padded/truncated to.
    """

    # Compared against the lower-cased file name, so ".JPG" etc. also match.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, tokenizer, transform=None, max_length=32):
        self.root_dir = root_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Only real, non-hidden directories are classes. Stray files or
        # tool folders (e.g. ".DS_Store", ".ipynb_checkpoints") would
        # otherwise crash the scan below or shift the label indices.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        self.samples = []

        for cls_name in self.classes:
            cls_path = os.path.join(root_dir, cls_name)
            label = self.class_to_idx[cls_name]
            # Sorted so the sample order does not depend on the file system.
            for file_name in sorted(os.listdir(cls_path)):
                if file_name.startswith("."):
                    continue
                if file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(cls_path, file_name), label))

    def __len__(self):
        """Return the number of (image path, label) samples found."""
        return len(self.samples)

    def extract_text_from_filename(self, filepath):
        """Turn a file name such as ``banana_peel_10.png`` into ``"banana peel"``.

        The extension is removed, the stem is split on underscores, a trailing
        all-digit part is dropped, and the rest is joined with spaces.

        Args:
            filepath: Path (or bare file name) of the image.

        Returns:
            The derived text; may be empty if the stem is only a number.
        """
        filename = os.path.basename(filepath)
        name = os.path.splitext(filename)[0]

        parts = name.split("_")
        if parts[-1].isdigit():
            parts = parts[:-1]

        return " ".join(parts)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            Tuple ``(image, input_ids, attention_mask, label)`` where the two
            token tensors have had their leading batch dimension squeezed out.
        """
        image_path, label = self.samples[idx]

        # ``convert`` returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        text = self.extract_text_from_filename(image_path)

        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # The tokenizer returns a batch of one; drop that batch dimension so
        # the DataLoader can stack samples itself.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        return image, input_ids, attention_mask, label


# =========================
# Training Functions
# =========================
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over *loader* and report epoch metrics.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        loader: Iterable of ``(images, input_ids, attention_mask, labels)``
            batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``. It is assumed
            to return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` computed over every sample
        seen, or ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.train()

    # Send batches to wherever the model's weights live; only a model without
    # parameters falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    loss_sum = 0.0
    correct = 0
    seen = 0

    progress = tqdm(loader)

    for images, input_ids, attention_mask, labels in progress:

        images = images.to(target)
        input_ids = input_ids.to(target)
        attention_mask = attention_mask.to(target)
        labels = labels.to(target)

        optimizer.zero_grad()

        outputs = model(images, input_ids, attention_mask)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()

        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size

        progress.set_postfix(loss=batch_loss, acc=100 * correct / max(seen, 1))

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, 100 * correct / seen


def evaluate(model, loader, criterion):
    """Measure loss and accuracy of *model* on *loader* without gradients.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        loader: Iterable of ``(images, input_ids, attention_mask, labels)``
            batches.
        criterion: Loss called as ``criterion(outputs, labels)``. It is assumed
            to return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` computed over every sample
        seen, or ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.eval()

    # Send batches to wherever the model's weights live; only a model without
    # parameters falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for images, input_ids, attention_mask, labels in loader:

            images = images.to(target)
            input_ids = input_ids.to(target)
            attention_mask = attention_mask.to(target)
            labels = labels.to(target)

            outputs = model(images, input_ids, attention_mask)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)

            # Weight by batch size so a short final batch does not skew the
            # reported mean.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, 100 * correct / seen


# =========================
# Main
# =========================
if __name__ == "__main__":

    # Resize to 224x224 and normalize with the channel statistics commonly
    # used for ImageNet-pretrained torchvision models.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

    train_dataset = GarbageDataset(
        root_dir="garbage_data/CVPR_2024_dataset_Train",
        tokenizer=tokenizer,
        transform=transform
    )

    val_dataset = GarbageDataset(
        root_dir="garbage_data/CVPR_2024_dataset_Val",
        tokenizer=tokenizer,
        transform=transform
    )

    # Labels are positions in each dataset's own sorted class list, so the two
    # splits must expose identical classes or validation labels would be
    # silently misaligned with the training labels.
    if val_dataset.classes != train_dataset.classes:
        raise ValueError(
            "Train/val class folders differ: "
            f"{train_dataset.classes} vs {val_dataset.classes}"
        )

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

    # Size the classifier head from the data instead of a hard-coded count.
    model = MultiModalModel(num_classes=len(train_dataset.classes)).to(device)

    criterion = nn.CrossEntropyLoss()

    # Only train classifier parameters
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)

    num_epochs = 5
    # Start below any reachable accuracy so the first epoch always produces a
    # checkpoint, even if validation accuracy is 0%.
    best_val_acc = float("-inf")

    for epoch in range(num_epochs):

        print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")

        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc:.2f}%")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print("Best model saved.")
