In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel
import os
import re
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# Configuration: number of output classes shared by both models, and the compute
# device (CUDA GPU when available, otherwise CPU).
NUM_CLASSES = 4
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
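This notebook loads the checkpoints for the best text model and the best image model. The test split of the 15 GB garbage_data dataset is fed to both models, and their outputs are combined with two fusion rules: confidence selection (use the prediction of whichever model is more confident) and probability averaging.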

Note: We attempted to load the full saved model, but it resulted in lower accuracy compared to loading from checkpoints. In this notebook, the model architecture is explicitly defined, and the checkpoint weights are loaded into that architecture.

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [5]:
# Re-define the text-model architecture so the checkpoint weights can be loaded into it.
class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder followed by dropout and a linear head on the first token.

    The attribute names (``distilbert``, ``dropout``, ``classifier``) must stay as
    they are: they form the keys of the saved checkpoint's state dict.

    Args:
        num_classes: number of output logits produced by the head.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        hidden_size = self.distilbert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)

        # Freeze every encoder parameter; only the linear head is trainable.
        self.distilbert.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of token ids and attention masks."""
        hidden_states = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Sequence position 0 holds the [CLS] token added by the tokenizer.
        return self.classifier(self.dropout(hidden_states[:, 0]))

In [None]:
# Load the text model: rebuild the architecture on the target device, then restore
# the best checkpoint's weights and switch to inference mode.
# NOTE(review): torch.load unpickles the file (restricted only when weights_only is
# in effect for the installed torch version) — only load trusted checkpoints.
text_model = DistilBERTClassifier(NUM_CLASSES).to(device)
checkpoint = torch.load("best_text_model.pth", map_location=device)
text_model.load_state_dict(checkpoint["model_state_dict"])
text_model.eval()

# Tokenizer must match the one used in training (same pretrained vocabulary).
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [15]:
import os
import re
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

# Define the dataset
class MultimodalDataset(Dataset):
    """Test-time dataset pairing each image with text derived from its file name.

    Images and integer labels come from ``torchvision.datasets.ImageFolder``
    (one sub-folder per class under ``image_dir``). The text input is the file
    name without extension, with underscores replaced by spaces and all digits
    removed (e.g. ``plastic_bag_12.png`` -> ``"plastic bag "``). It is then
    encoded with ``tokenizer`` to exactly ``max_len`` tokens.

    Args:
        image_dir: root directory passed to ``ImageFolder``.
        transform: image transform applied by ``ImageFolder``.
        tokenizer: Hugging Face-style tokenizer, called with ``return_tensors="pt"``.
        max_len: token length every sample is padded / truncated to.
    """

    def __init__(self, image_dir, transform, tokenizer, max_len=24):
        # ImageFolder assigns one integer label per class sub-folder.
        self.image_dataset = datasets.ImageFolder(image_dir, transform=transform)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.image_dataset)

    @staticmethod
    def _filename_to_text(path):
        """Turn an image file path into the text string fed to the text model."""
        stem = os.path.splitext(os.path.basename(path))[0]
        # Underscores become spaces, then any run of digits is removed.
        return re.sub(r'\d+', '', stem.replace('_', ' '))

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``input_ids``, ``attention_mask`` and ``label``."""
        image, label = self.image_dataset[idx]
        path = self.image_dataset.samples[idx][0]
        text = self._filename_to_text(path)

        # Use the tokenizer passed to the constructor, not a notebook-global one.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # squeeze(0) removes only the leading batch dimension added by
        # return_tensors="pt", so the tensors stay 1-D even when max_len == 1.
        return {
            "image": image,
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": label,
        }

In [None]:
# DATALOADER
# Test-time preprocessing for the image model: resize to 288x288, convert to a
# tensor and normalize with the standard ImageNet per-channel mean/std.
# (A CenterCrop(288) after a (288, 288) resize is a no-op, so it is omitted.)
transform_test = transforms.Compose([
    transforms.Resize((288, 288)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    ),
])

# Defaults to the Colab Drive location; set the TEST_PATH env var to run elsewhere.
TEST_PATH = os.environ.get("TEST_PATH", "/content/drive/MyDrive/dataTest")
test_dataset = MultimodalDataset(TEST_PATH, transform_test, tokenizer)

# ImageFolder labels must line up with the NUM_CLASSES outputs of both models;
# a wrong folder would otherwise silently produce misaligned labels.
n_class_folders = len(test_dataset.image_dataset.classes)
assert n_class_folders == NUM_CLASSES, (
    f"Expected {NUM_CLASSES} class folders in {TEST_PATH}, found {n_class_folders}"
)

# shuffle=False keeps predictions in the same order as the dataset samples.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## Confusion-matrix plotting helper

In [None]:


def plot_confusion_matrix(y_true, y_pred,
                          title='Confusion Matrix',
                          normalize=False,
                          class_names=None):
    """Plot a heatmap of true vs. predicted labels.

    Args:
        y_true: sequence of integer ground-truth labels.
        y_pred: sequence of integer predicted labels (same length as ``y_true``).
        title: figure title.
        normalize: if True, divide each row by its total so cells show the fraction
            of that true class; rows with no samples are shown as 0 instead of NaN.
        class_names: display names indexed by label id; defaults to
            ``['Black', 'Blue', 'Green', 'TTR']``.
    """
    if class_names is None:
        class_names = ['Black', 'Blue', 'Green', 'TTR']

    # Pin the label set so the matrix is always len(class_names) square and its
    # rows/columns line up with the tick labels, even if a class never occurs.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Avoid 0/0 -> NaN for classes absent from y_true.
        cm = np.divide(cm.astype(float), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        fmt = ".2f"
    else:
        fmt = "d"

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

    plt.show()


## Helper to display misclassified test images


In [None]:


def show_misclassified(image_model, text_model, test_loader, device, max_show=8,
                       class_names=('Black', 'Blue', 'Green', 'TTR'),
                       mean=(0.485, 0.456, 0.406),
                       std=(0.229, 0.224, 0.225),
                       n_cols=4):
    """Show up to ``max_show`` test images that the confidence-selection fusion gets wrong.

    For each sample, the prediction of whichever model has the higher maximum
    softmax probability is used (ties go to the text model). This is the same
    rule as the confidence-selection prediction cell.

    Args:
        image_model: image classifier called as ``image_model(images)`` -> logits.
        text_model: text classifier called as
            ``text_model(input_ids, attention_mask)`` -> logits.
        test_loader: yields dicts with ``image``, ``input_ids``,
            ``attention_mask`` and ``label``.
        device: device the models live on.
        max_show: maximum number of misclassified images to draw.
        class_names: display names indexed by label id.
        mean, std: per-channel values passed to ``transforms.Normalize`` in the
            test transform. They are used to undo the normalization so images
            are shown with their real colours.
        n_cols: number of subplot columns; rows are added as needed.
    """
    if max_show <= 0:
        return

    image_model.eval()
    text_model.eval()

    # Shape (C, 1, 1) so they broadcast over a (C, H, W) image tensor.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)

    n_rows = -(-max_show // n_cols)  # ceiling division
    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows))
    shown = 0

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            image_probs = F.softmax(image_model(images), dim=1)
            text_probs = F.softmax(text_model(input_ids, attention_mask), dim=1)

            # Confidence selection: keep the more confident model's prediction.
            image_conf, image_pred = torch.max(image_probs, dim=1)
            text_conf, text_pred = torch.max(text_probs, dim=1)
            final_pred = torch.where(image_conf > text_conf, image_pred, text_pred)

            wrong_idx = (final_pred != labels).nonzero(as_tuple=True)[0].tolist()
            for i in wrong_idx:
                # Undo Normalize and clamp to [0, 1] so imshow shows true colours.
                img = (images[i].cpu() * std_t + mean_t).clamp(0, 1).permute(1, 2, 0)

                plt.subplot(n_rows, n_cols, shown + 1)
                plt.imshow(img)
                plt.axis("off")

                pred_name = class_names[final_pred[i].item()]
                true_name = class_names[labels[i].item()]
                plt.title(f"Prediction:{pred_name}\nActual:{true_name}")

                shown += 1
                if shown >= max_show:
                    break

            if shown >= max_show:
                break

    if shown == 0:
        plt.close(fig)
        print("No misclassified examples found.")
        return

    plt.suptitle("Misclassified Examples")
    plt.show()

In [None]:
# PREDICTION — fusion strategy 1: confidence selection.
# For each sample, keep the prediction of whichever model is more confident, i.e.
# has the larger maximum softmax probability; ties go to the text model.
# Example: image probs (0.7, 0.2, 0.1, 0.0) vs text probs (0.6, 0.2, 0.2, 0.0)
# -> 0.7 > 0.6, so the image model's prediction is used.
all_preds, all_labels = [], []

image_model.eval()
text_model.eval()

with torch.no_grad():
    for batch in test_loader:
        image_logits = image_model(batch["image"].to(device))
        text_logits = text_model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device),
        )

        image_conf, image_pred = F.softmax(image_logits, dim=1).max(dim=1)
        text_conf, text_pred = F.softmax(text_logits, dim=1).max(dim=1)

        final_pred = torch.where(image_conf > text_conf, image_pred, text_pred)

        all_preds.extend(final_pred.cpu().numpy())
        all_labels.extend(batch["label"].numpy())



In [18]:
# Fraction of test samples whose confidence-selected prediction matches the label.
accuracy = (np.asarray(all_preds) == np.asarray(all_labels)).mean()
print("Confidence Fusion Accuracy:", accuracy)

Confidence Fusion Accuracy: 0.8257575757575758


In [19]:
# PREDICTION — fusion strategy 2: average the two models' softmax probabilities
# with equal weight and predict the class with the highest mean probability.
all_preds, all_labels = [], []

image_model.eval()
text_model.eval()

with torch.no_grad():
    for batch in test_loader:
        image_probs = F.softmax(image_model(batch["image"].to(device)), dim=1)
        text_probs = F.softmax(
            text_model(
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
            ),
            dim=1,
        )

        fused_pred = ((image_probs + text_probs) / 2).argmax(dim=1)

        all_preds.extend(fused_pred.cpu().numpy())
        all_labels.extend(batch["label"].numpy())

In [20]:
# Fraction of test samples whose probability-averaged prediction matches the label.
# This scores the averaging fusion from the cell above, so it is labelled as such
# (it previously reused the "Confidence Fusion" label).
accuracy = (np.asarray(all_preds) == np.asarray(all_labels)).mean()
print("Average Fusion Accuracy:", accuracy)

Confidence Fusion Accuracy: 0.872086247086247
