In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import models, datasets, transforms
from torch.utils.data import Dataset, DataLoader

from transformers import DistilBertModel, DistilBertTokenizer

import os
import re
import numpy as np

# =========================
# CONFIG
# =========================

NUM_CLASSES = 4
BATCH_SIZE = 32
EPOCHS = 5
LEARNING_RATE = 0.001

DATA_DIR = "/work/TALC/ensf617_2026w/garbage_data"

TRAIN_PATH = r"C:\Users\john2\Desktop\uofc\617\assignment2\garbage_data\garbage_data\CVPR_2024_dataset_Train"
VAL_PATH = r"C:\Users\john2\Desktop\uofc\617\assignment2\garbage_data\garbage_data\CVPR_2024_dataset_Val"
TEST_PATH = r"C:\Users\john2\Desktop\uofc\617\assignment2\garbage_data\garbage_data\CVPR_2024_dataset_Test"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

# =========================
# TRANSFORMS
# =========================

transform = transforms.Compose([
    transforms.Resize((288,288)),
    transforms.CenterCrop(288),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

# =========================
# LOAD IMAGE MODEL
# =========================

print("Loading image model...")

image_model = models.efficientnet_b2(weights=None)

in_features = image_model.classifier[1].in_features

image_model.classifier = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Dropout(0.5),

    nn.Linear(512, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.3),

    nn.Linear(128, NUM_CLASSES),
)

image_model.load_state_dict(
    torch.load("best_image_model.pth", map_location=device)
)

image_model = image_model.to(device)
image_model.eval()

print("Image model loaded")

# =========================
# LOAD TEXT MODEL
# =========================

print("Loading text model...")

class DistilBERTClassifier(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        self.distilbert = DistilBertModel.from_pretrained(
            "distilbert-base-uncased"
        )

        self.dropout = nn.Dropout(0.3)

        self.classifier = nn.Linear(
            self.distilbert.config.hidden_size,
            num_classes
        )

    def forward(self, input_ids, attention_mask):

        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        x = outputs.last_hidden_state[:,0]

        x = self.dropout(x)

        return self.classifier(x)


text_model = DistilBERTClassifier(NUM_CLASSES)

checkpoint = torch.load(
    "best_text_model.pt",
    map_location=device
)

text_model.load_state_dict(
    checkpoint["model_state_dict"]
)

text_model = text_model.to(device)
text_model.eval()

print("Text model loaded")

tokenizer = DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased"
)

# =========================
# DATASET
# =========================

class MultimodalDataset(Dataset):

    def __init__(self, path, transform, tokenizer, max_len=24):

        self.dataset = datasets.ImageFolder(
            path,
            transform=transform
        )

        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):

        image, label = self.dataset[idx]

        path = self.dataset.samples[idx][0]

        filename = os.path.basename(path)

        text = os.path.splitext(filename)[0]
        text = text.replace("_"," ")
        text = re.sub(r'\d+',"",text)

        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        return {
            "image": image,
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "label": label
        }

# =========================
# DATALOADERS
# =========================

train_loader = DataLoader(
    MultimodalDataset(TRAIN_PATH, transform, tokenizer),
    batch_size=BATCH_SIZE,
    shuffle=True
)

test_loader = DataLoader(
    MultimodalDataset(TEST_PATH, transform, tokenizer),
    batch_size=BATCH_SIZE,
    shuffle=False
)

# =========================
# FEATURE EXTRACTORS
# =========================

class ImageFeatureExtractor(nn.Module):

    def __init__(self, image_model):
        super().__init__()

        self.features = image_model.features
        self.avgpool = image_model.avgpool

        self.feature_layer = nn.Sequential(
            image_model.classifier[0],
            image_model.classifier[1],
            image_model.classifier[2],
            image_model.classifier[3],
            image_model.classifier[4],
            image_model.classifier[5],
            image_model.classifier[6],
            image_model.classifier[7],
        )

    def forward(self,x):

        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x,1)
        x = self.feature_layer(x)

        return x


class TextFeatureExtractor(nn.Module):

    def __init__(self,text_model):
        super().__init__()

        self.distilbert = text_model.distilbert
        self.dropout = text_model.dropout

    def forward(self,input_ids,attention_mask):

        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        x = outputs.last_hidden_state[:,0]

        x = self.dropout(x)

        return x

# =========================
# FUSION MODEL
# =========================

class FusionModel(nn.Module):

    def __init__(self,image_model,text_model,num_classes):

        super().__init__()

        self.image_extractor = ImageFeatureExtractor(image_model)
        self.text_extractor = TextFeatureExtractor(text_model)

        self.classifier = nn.Sequential(

            nn.Linear(896,256),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(256,num_classes)
        )

    def forward(self,image,input_ids,attention_mask):

        image_feat = self.image_extractor(image)

        text_feat = self.text_extractor(
            input_ids,
            attention_mask
        )

        combined = torch.cat(
            (image_feat,text_feat),
            dim=1
        )

        output = self.classifier(combined)

        return output

# =========================
# CREATE FUSION MODEL
# =========================

fusion_model = FusionModel(
    image_model,
    text_model,
    NUM_CLASSES
).to(device)

# Freeze extractors

for param in fusion_model.image_extractor.parameters():
    param.requires_grad = False

for param in fusion_model.text_extractor.parameters():
    param.requires_grad = False

optimizer = optim.Adam(
    fusion_model.classifier.parameters(),
    lr=LEARNING_RATE
)

criterion = nn.CrossEntropyLoss()

# =========================
# TRAIN FUSION MODEL
# =========================

print("Training fusion model...")

for epoch in range(EPOCHS):

    fusion_model.train()

    total_loss = 0

    for batch in train_loader:

        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = fusion_model(
            images,
            input_ids,
            attention_mask
        )

        loss = criterion(outputs,labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss {total_loss:.4f}")

# =========================
# TEST
# =========================

fusion_model.eval()

correct = 0
total = 0

with torch.no_grad():

    for batch in test_loader:

        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = fusion_model(
            images,
            input_ids,
            attention_mask
        )

        _, preds = torch.max(outputs,1)

        correct += torch.sum(preds==labels).item()
        total += labels.size(0)

print("Fusion Accuracy:", correct/total)

Using device: cuda
Loading image model...
Image model loaded
Loading text model...
Text model loaded
Training fusion model...
Epoch 1, Loss 65.1989
Epoch 2, Loss 58.5814
Epoch 3, Loss 53.7672
Epoch 4, Loss 53.8812
Epoch 5, Loss 53.4769
Fusion Accuracy: 0.8642191142191142
