In [None]:
!pip install torch
!pip install sklearn

In [1]:
import os
from tqdm.auto import tqdm

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

from sklearn.metrics import confusion_matrix

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
root_path = "/content/drive/MyDrive/NUS/CS4243/CS4243_mini_project"
model_root_path = os.path.join(root_path, "models")

image_data_path = os.path.join(root_path, "image_data_cleaned_split/test")
spec_data_path = os.path.join(root_path, "spectrogram_data_split/test")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_model_name = "inception_ensemble_image_classifier_lr3_e20_elr7"
spec_model_name = "inception_ensemble_spectrogram_classifier_lr3_e20_elr9"

In [4]:
def create_model(num_classes: int = 3, model_path = None):
    model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=model_path is None)
    model.AuxLogits.fc = nn.Linear(768, num_classes)
    model.fc = nn.Linear(2048, num_classes)
    if model_path:
        model.load_state_dict(torch.load(model_path))
    return model.to(device)

In [5]:
image_model = create_model(model_path=os.path.join(model_root_path, image_model_name))
image_model.eval()
spec_model = create_model(model_path=os.path.join(model_root_path, spec_model_name))
spec_model.eval()
print("Models loaded!")

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip


KeyboardInterrupt: ignored

In [None]:
input_shape = (299, 299)

# Image transformations
img_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.Resize(input_shape),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

spec_transform = transforms.Compose([
    transforms.Resize(input_shape),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
class ImageSpecDataset(Dataset):

    def __init__(self, image_data_path, spec_data_path):
        labels = [x for x in os.listdir(image_data_path) if os.path.isdir(os.path.join(image_data_path, x))]
        all_image_data = []
        all_image_name_to_spec_data_map = {}
        for i, label in enumerate(labels):
            label_image_path = os.path.join(image_data_path, label)
            label_spec_path = os.path.join(spec_data_path, label)
            image_names = [x for x in os.listdir(label_image_path) if os.path.isfile(os.path.join(label_image_path, x))]
            spec_names = [x for x in os.listdir(label_spec_path) if os.path.isfile(os.path.join(label_spec_path, x))]
            print(f'\nChecking Test Dataset of {label}')
            print(f'Checking spectrogram files')
            for spec_name in tqdm(spec_names):
                spec_data = Image.open(os.path.join(label_spec_path, spec_name)).convert('RGB')
                original_name, _, _ = spec_name.rpartition("_audio")
                for image_name in image_names:
                    if original_name in image_name:
                        all_image_name_to_spec_data_map[image_name] = spec_data
            print(f'Checking image files')
            for image_name in tqdm(image_names):
                image_data = Image.open(os.path.join(label_image_path, image_name))
                image_data = img_transform(image_data)
                all_image_data.append((image_data, image_name, i))
        self.image_data = all_image_data
        self.image_name_to_spec_data_map = all_image_name_to_spec_data_map
        print(len(all_image_data), "images found.")
        print(len(all_image_name_to_spec_data_map), "images with spectrograms found.")

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, i):
        image, image_name, label = self.image_data[i]
        if image_name in self.image_name_to_spec_data_map:
            spec = self.image_name_to_spec_data_map[image_name]
            return image, spec_transform(spec), label, 1
        else:
            spec = Image.new('RGB', input_shape)
            return image, spec_transform(spec), label, 0

In [None]:
dataset = ImageSpecDataset(image_data_path, spec_data_path)


Checking Test Dataset of carrying
Checking spectrogram files


  0%|          | 0/50 [00:00<?, ?it/s]

Checking image files


  0%|          | 0/179 [00:00<?, ?it/s]


Checking Test Dataset of normal
Checking spectrogram files


  0%|          | 0/54 [00:00<?, ?it/s]

Checking image files


  0%|          | 0/261 [00:00<?, ?it/s]


Checking Test Dataset of threat
Checking spectrogram files


  0%|          | 0/57 [00:00<?, ?it/s]

Checking image files


  0%|          | 0/142 [00:00<?, ?it/s]

582 images found.
518 images with spectrograms found.


In [None]:
def collator(batch):
    images, specs, labels, masks = zip(*batch)
    images = torch.stack(images).float()
    specs = torch.stack(specs).float()
    labels = torch.tensor(labels)
    masks = torch.tensor(masks)
    return images, specs, labels, masks

In [None]:
batch_size = 16

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collator, num_workers=2)

In [None]:
def get_labels(logit, size):
    return torch.max(logit, dim=1)[1].view(size)

def get_accuracy(logit, target, batch_size):
    corrects = (get_labels(logit, target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects/batch_size
    return accuracy.item()

In [None]:
test_acc = 0
image_weight = 0.93
spec_weight = 0.57
total_conf_table = np.zeros((3, 3))

pbar = tqdm(total=len(dataloader))

for i, (images, specs, labels, spec_mask) in enumerate(dataloader):
    images = images.to(device)
    specs = specs.to(device)
    labels = labels.to(device)
    spec_mask = spec_mask.to(device)

    # ensemble forward step
    image_logit = image_model(images)
    spec_logit = spec_model(specs)
    masked_spec = torch.transpose(spec_mask * torch.transpose(spec_logit, 0, 1), 0, 1)

    # combine ensemble logits
    logit = (image_logit * image_weight + masked_spec  * spec_weight) / (image_weight + spec_weight)

    # get predicted labels and calculate metrics
    pred = get_labels(logit, labels.size())
    test_acc += get_accuracy(logit, labels, batch_size)
    pred_np = pred.cpu().detach().numpy()
    label_np = labels.cpu().detach().numpy()
    conf_table = confusion_matrix(label_np, pred_np, labels=[0, 1, 2])
    total_conf_table += conf_table
    pbar.update(1)
pbar.close()


print("Test Accuracy: %.2f" %(test_acc/i)) 
print("Confusion Table:")
print(total_conf_table)

  0%|          | 0/37 [00:00<?, ?it/s]

Test Accuracy: 82.64
Confusion Table:
[[113.  37.  29.]
 [  7. 230.  24.]
 [  1.   8. 133.]]
