In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel
import os
import re
import numpy as np
# Configuration: number of target classes and the compute device (GPU when CUDA is available).

NUM_CLASSES = 4  # output size of both the text and the image classifier heads
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

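This notebook loads the checkpoints of the best text model and the best image model. The test split of the 15 GB garbage_data dataset is fed to both models, and their outputs are combined with a confidence-selection rule: for each sample, the prediction of whichever model has the higher softmax confidence is kept.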

Note: we also tried loading the full saved models, but that gave lower accuracy than loading from checkpoints. This notebook therefore redefines each model architecture explicitly and loads the checkpoint weights into it.

In [None]:
# Text model architecture, redefined so the saved checkpoint weights can be loaded into it.
class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder followed by dropout and a linear classification head.

    The attribute names ``distilbert``, ``dropout`` and ``classifier`` must not
    be renamed: they determine the keys of the model's state dict, which has to
    match the saved checkpoint.

    Args:
        num_classes: number of logits produced by the classification head.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Pretrained backbone; a strict load_state_dict of a full checkpoint replaces these weights.
        self.distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        hidden_size = self.distilbert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)

        # Freeze every backbone parameter. This only affects gradient tracking;
        # the forward computation is unchanged.
        self.distilbert.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the final hidden state of the first token.

        The first position holds the [CLS] token when the tokenizer adds special tokens.
        """
        hidden_states = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.classifier(self.dropout(first_token))

In [None]:
# Load the text model: rebuild the architecture, then restore the fine-tuned weights.
text_model = DistilBERTClassifier(NUM_CLASSES)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load("best_text_model.pth", map_location=device)
text_model.load_state_dict(checkpoint["model_state_dict"])

# Move to the compute device and switch to inference mode (disables dropout).
text_model = text_model.to(device).eval()

# Tokenizer for the same "distilbert-base-uncased" vocabulary as the backbone.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [44]:
# Image model: EfficientNet-B2 with the custom classification head used in training.
# weights=None skips the ImageNet download; every weight is loaded from the checkpoint below.
image_model = models.efficientnet_b2(weights=None)

# The head layers must appear in exactly this order. Their positions become the
# state-dict keys ("classifier.0", "classifier.1", ...) that the checkpoint expects.
in_features = image_model.classifier[1].in_features
image_model.classifier = nn.Sequential(
    # hidden block 1: in_features -> 512
    nn.Linear(in_features, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.5),
    # hidden block 2: 512 -> 128
    nn.Linear(512, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
    # output logits
    nn.Linear(128, NUM_CLASSES),
)

# Restore the trained weights (tensor-only load) and prepare for inference.
image_model.load_state_dict(
    torch.load("best_image_model.pth", map_location=device, weights_only=True)
)
image_model = image_model.to(device).eval()

In [None]:

# Multimodal dataset: each sample pairs an image with text derived from its file name.
class MultimodalDataset(Dataset):
    """Image plus filename-text dataset built on ``torchvision.datasets.ImageFolder``.

    Labels come from ImageFolder, which assigns class indices from the sorted
    class-folder names. This assumes the models were trained with the same
    folder-to-index mapping (TODO confirm against the training notebooks).

    The text input is the file name with its extension removed, underscores
    turned into spaces and digits stripped, e.g. ``plastic_bag_12.png`` gives
    ``"plastic bag "``.

    Args:
        image_dir: root directory with one sub-folder per class.
        transform: transform applied to every loaded image.
        tokenizer: tokenizer (e.g. the DistilBERT tokenizer) called on the filename text.
        max_len: every text sample is padded/truncated to this many tokens.
    """

    def __init__(self, image_dir, transform, tokenizer, max_len=24):
        # ImageFolder assigns the labels automatically.
        self.image_dataset = datasets.ImageFolder(image_dir, transform=transform)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.image_dataset)

    def __getitem__(self, idx):
        """Return a dict with the image, raw text, token ids, attention mask and label."""
        image, label = self.image_dataset[idx]

        # Derive the text from the file name: ".../plastic_bag_12.png" -> "plastic bag ".
        path = self.image_dataset.samples[idx][0]
        stem = os.path.splitext(os.path.basename(path))[0]
        text = re.sub(r'\d+', '', stem.replace('_', ' '))

        # Use the tokenizer given to the constructor, never a notebook-global one.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # return_tensors='pt' adds a leading batch dimension of size 1. Drop only that
        # dimension so the result is always 1-D; a bare .squeeze() would also collapse
        # a length-1 sequence into a 0-d tensor.
        return {
            "image": image,
            "text": text,
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": label,
        }

In [None]:
#DATALOADER
# Deterministic test-time preprocessing (no augmentation): resize to 288x288 and
# normalize with the standard ImageNet channel mean/std (presumably matching training).
transform_test = transforms.Compose([
    transforms.Resize((288, 288)),
    # No-op after the exact-size resize above; kept so the pipeline is unchanged.
    transforms.CenterCrop(288),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    ),
])

# Set the GARBAGE_TEST_PATH environment variable to run on another machine;
# the default is the original author's local dataset location.
TEST_PATH = os.environ.get(
    "GARBAGE_TEST_PATH",
    r"C:\Users\john2\Desktop\uofc\617\assignment2\garbage_data\garbage_data\CVPR_2024_dataset_Test",
)
test_dataset = MultimodalDataset(
    TEST_PATH,
    transform_test,
    tokenizer
)

# shuffle=False keeps the predictions in the dataset's sample order.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False
)

In [None]:
#PREDICTION
# Run both models over the test set and fuse their outputs per sample by confidence.
# F, np and torch are already imported in the first cell.


def fuse_by_confidence(image_logits, text_logits):
    """Pick, per sample, the class predicted by the more confident model.

    Confidence is the maximum softmax probability of each model's output.
    For example, image probabilities (0.7, 0.1, 0.1, 0.1) and text
    probabilities (0.6, 0.2, 0.2, 0.0) select the image model's class 0.
    Ties go to the text model, because the image model must be strictly
    more confident to be chosen.

    Args:
        image_logits: (batch, num_classes) logits from the image model.
        text_logits: (batch, num_classes) logits from the text model.

    Returns:
        Tensor of shape (batch,) with the fused class indices.
    """
    image_conf, image_pred = torch.max(F.softmax(image_logits, dim=1), dim=1)
    text_conf, text_pred = torch.max(F.softmax(text_logits, dim=1), dim=1)
    return torch.where(image_conf > text_conf, image_pred, text_pred)


all_preds = []
all_labels = []

image_model.eval()
text_model.eval()

with torch.no_grad():
    for batch in test_loader:
        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        final_pred = fuse_by_confidence(
            image_model(images),
            text_model(input_ids, attention_mask),
        )

        all_preds.extend(final_pred.cpu().numpy())
        # Labels are only used for CPU-side scoring, so they never need to visit the GPU.
        all_labels.extend(batch["label"].cpu().numpy())

In [None]:
#ACCURACY
# Fraction of test samples whose fused prediction equals the ImageFolder label.
accuracy = (np.asarray(all_preds) == np.asarray(all_labels)).mean()

print("Confidence Fusion Accuracy:", accuracy)

Confidence Fusion Accuracy: 0.8703379953379954
