<a href="https://colab.research.google.com/github/Bordi00/Network-Security-23-24/blob/main/TTA_Memo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Init workspace
!rm -r dataset
!mkdir dataset

# Download dataset and extract it
!gdown 1WKQGHjHUkIwZT0P2TpU9h-lY-6CnrsDd
!mv imagenetv2-matched-frequency.tar.gz ./dataset
!tar -xf ./dataset/imagenetv2-matched-frequency.tar.gz
!mv imagenetv2-matched-frequency-format-val ./dataset

# Cleanup
!rm ./dataset/imagenetv2-matched-frequency.tar.gz

In [None]:
import torch
from pathlib import Path
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from os import listdir, path
from os.path import basename, isfile, join
import requests
import json
from PIL import Image, ImageOps
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as T
from copy import deepcopy
import pandas as pd
from torchvision.io import read_image
import random

# Report the library versions this run is using
print(f'PyTorch version {torch.__version__}')
print(f'Numpy version {np.__version__}')

In [None]:
# Download the ImageNet class-name list (read line by line further down to name each class index)
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt -O imagenet_classes.txt

In [None]:
# Load the human-readable class names, one per line; the line number is the class index
with open("imagenet_classes.txt", "r") as f:
    class_labels = [line.strip() for line in f]

# The extracted dataset has one sub-folder per class, named by its integer class index
dataset_dir = './dataset/imagenetv2-matched-frequency-format-val'

# Collect the class indices present on disk. Anything that is not a numerically
# named directory (hidden files, stray archives, ...) is skipped instead of
# crashing int(), and the indices are sorted so that everything derived from this
# mapping does not depend on the filesystem's listing order.
class_indices = sorted(
    int(entry)
    for entry in listdir(dataset_dir)
    if entry.isdecimal() and path.isdir(path.join(dataset_dir, entry))
)

# Map class index (== folder name as int) -> class name
folder_to_class = {class_index: class_labels[class_index] for class_index in class_indices}

In [None]:
# Prefer CUDA when it is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device', device)

In [None]:
def load_model(weights='default'):
    """Build a ResNet-50 initialised with pretrained ImageNet weights.

    Args:
        weights: which pretrained weights to load. Either a ``ResNet50_Weights``
            member or the (case-insensitive) name of one, e.g. ``'default'`` or
            ``'imagenet1k_v1'``. Any other value falls back to
            ``ResNet50_Weights.IMAGENET1K_V1``, as before.

    Returns:
        The ``torchvision.models.resnet50`` instance.
    """
    if isinstance(weights, ResNet50_Weights):
        chosen = weights
    elif isinstance(weights, str) and weights.upper() in ResNet50_Weights.__members__:
        # Enum member names are upper-case; the raw lower-case string 'default'
        # is not a valid lookup key, so resolve it explicitly.
        chosen = ResNet50_Weights[weights.upper()]
    else:
        chosen = ResNet50_Weights.IMAGENET1K_V1
    # `weights` is passed by keyword; positional use is deprecated by torchvision
    return models.resnet50(weights=chosen)

In [None]:
# Build the pretrained network and place it on the same device the input batches
# are moved to later on; otherwise the forward pass fails when a GPU is in use.
model = load_model('default').to(device)
# Switch to evaluation mode (as the cell's last expression this also displays the module)
model.eval()

In [None]:
class ImagenetDataset(Dataset):
    """Map-style dataset over a directory tree with one sub-folder per class.

    Args:
        labels: mapping whose keys name the class sub-folders of ``img_dir``
            (``str(key)`` is the folder name). The key itself is used as the
            label of every image found in that folder.
        img_dir: root directory containing the class sub-folders.
        transform: optional callable applied to the loaded PIL image.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, labels, img_dir, transform=None, target_transform=None):
        self.img_labels = []  # list of (image path, label) pairs
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        # Pair every regular file of each class folder with that folder's key
        for folder_name in labels.keys():
            full_path = path.join(img_dir, str(folder_name))
            # sorted() makes the sample order independent of the OS listing order
            for img_file in sorted(listdir(full_path)):
                img_path = path.join(full_path, img_file)
                if path.isfile(img_path):
                    self.img_labels.append((img_path, folder_name))

    def __len__(self):
        """Number of (image, label) samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``idx``-th (image, label) pair, with the transforms applied if set."""
        img_path, label = self.img_labels[idx]
        # Force 3-channel RGB: a grayscale / CMYK / RGBA file would otherwise give a
        # tensor with a different channel count, breaking the 3-channel Normalize
        # and batch collation. The context manager closes the file handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label



In [None]:
def load_dataset(visualization=False) -> Dataset:
    """Build the ImageNetV2 dataset with the standard resize / center-crop preprocessing.

    Args:
        visualization: when True the mean/std normalization step is left out, so
            the returned tensors stay in the [0, 1] range produced by ``ToTensor``
            and can be displayed directly. When False (default) the ImageNet
            normalization is appended.

    Returns:
        An ``ImagenetDataset`` (a ``Dataset``, not a ``DataLoader``) yielding
        ``(image_tensor, class_index)`` pairs.
    """
    preprocess_steps = [
        transforms.Resize(256),               # Resize the shortest side to 256 pixels
        transforms.CenterCrop(224),           # Crop to 224x224 pixels around the center
        transforms.ToTensor(),                # Convert image to PyTorch tensor [0, 1] range
    ]

    # Normalization is only wanted when the tensors are fed to the model
    if not visualization:
        preprocess_steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        )

    preprocess = transforms.Compose(preprocess_steps)
    # Reuse the notebook-level dataset_dir rather than repeating the path literal
    return ImagenetDataset(folder_to_class, img_dir=dataset_dir, transform=preprocess)

In [None]:

# Dataset without normalization, so the samples can be displayed as they are
imagenet_v2_dataset = load_dataset(visualization=True)


## Visualize images and their corresponding labels

In [None]:
def visualize_images_with_desc(image_tensors, labels, figsize=(15, 6)):
    """Show the given image tensors side by side in one row, each titled with its label.

    Args:
        image_tensors: sequence of 3-D image tensors (assumed channels-first).
        labels: sequence of titles, one per image.
        figsize: size of the matplotlib figure.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    n_images = len(image_tensors)
    if n_images != len(labels):
        raise ValueError("The number of images must match the number of titles.")

    plt.figure(figsize=figsize)

    # One column per image; subplot positions are 1-based
    for position, (tensor, caption) in enumerate(zip(image_tensors, labels), start=1):
        # Move dim 0 to the end so imshow gets a channels-last array
        channels_last = tensor.permute(1, 2, 0)
        plt.subplot(1, n_images, position)
        plt.title(caption)
        plt.imshow(channels_last, cmap='gray')
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def folder_to_label(folder):
    """Return the class name for a class folder, given as an int or a numeric string."""
    class_index = int(folder)
    return folder_to_class[class_index]

In [None]:
images = []
image_labels = []
# Pick five random samples and collect each image with its readable class name
for _ in range(5):
    # randrange excludes the upper bound; randint(0, len(...)) could return
    # len(...) itself and raise IndexError.
    sample_index = random.randrange(len(imagenet_v2_dataset))
    image, folder = imagenet_v2_dataset[sample_index]
    images.append(image)
    image_labels.append(folder_to_label(folder))

visualize_images_with_desc(images, image_labels)


## Baseline Performance Evaluation

In [None]:
# Rebuild the dataset with the normalization step enabled, as needed for model evaluation
imagenet_v2_dataset = load_dataset(visualization=False)

In [None]:
batch_size = 10
# Create a DataLoader for the test set
test_loader = DataLoader(imagenet_v2_dataset, batch_size=batch_size, shuffle=False)

# The batches are moved to `device` below, so the model's weights must live there
# too. Module.to() moves the parameters in place, so repeating it is harmless.
model.to(device)
model.eval()

# Counters for the top-1 accuracy
correct = 0
total = 0

# Note: the DataLoader already yields batched tensors, so no unsqueeze/squeeze is
# needed here (unlike when indexing the dataset directly).

# Disable gradient calculation for inference
with torch.no_grad():
    for images, ground_truth in tqdm(test_loader):
        images = images.to(device)
        ground_truth = ground_truth.to(device)

        # Forward pass; the predicted class is the index of the largest logit
        outputs = model(images)
        predicted = outputs.argmax(dim=1)

        total += images.size(0)  # number of samples in this batch
        correct += (predicted == ground_truth).sum().item()

# Guard against an empty loader so the cell does not end in ZeroDivisionError
accuracy = 100 * correct / total if total else 0.0
print(f'Accuracy of the model on the test set: {accuracy:.2f}%')


## Memo


In [None]:
from torchvision.transforms import AugMix, InterpolationMode, ToPILImage, ToTensor

def create_augmented_batch(image, mode='augmix', n=8):
    """Build a batch made of the original image followed by ``n`` AugMix-augmented copies.

    Args:
        image: image tensor with 3 dimensions. A tensor with an extra leading
            batch dimension of size 1 (as produced by a ``batch_size=1``
            DataLoader) is accepted as well; that dimension is dropped.
        mode: currently unused; AugMix is always applied. Kept so existing
            callers that pass it keep working.
        n: number of augmented copies to generate.

    Returns:
        A tensor of ``n + 1`` stacked images, the un-augmented one first. The
        augmented copies are resized/cropped to 224x224, so the stack only
        succeeds when the input image is already 224x224.
    """
    # ToPILImage only accepts 2-D/3-D inputs, so unwrap a single-image batch first
    if getattr(image, 'ndim', None) == 4 and image.shape[0] == 1:
        image = image[0]

    # NOTE(review): the PIL conversion presumably expects float tensors in [0, 1]
    # (i.e. not mean/std normalized); confirm against the callers.
    image = ToPILImage()(image)
    preaugment = transforms.Compose([
        AugMix(severity=10, mixture_width=2),
        transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        ToTensor()
    ])
    # AugMix is random, so every call of the pipeline gives a different view
    augmentations = [preaugment(image) for _ in range(n)]

    # Convert the original back to a tensor so it can be stacked with the views
    image = ToTensor()(image)

    return torch.stack([image] + augmentations)

In [None]:
def show_image(img):
    """Display a single image tensor with matplotlib."""
    # Drop a leading size-1 dimension if there is one, then move dim 0 to the
    # end so imshow receives a channels-last array.
    channels_last = img.squeeze(0).permute(1, 2, 0)
    plt.imshow(channels_last)

def show_batch_images(batch_tensor):
    """Display every image of a batch tensor side by side in a single row."""
    n_images = batch_tensor.shape[0]
    fig, axs = plt.subplots(1, n_images, figsize=(n_images * 3, 3))

    # plt.subplots returns a bare Axes (not a sequence) when there is only one column
    axes = [axs] if n_images == 1 else axs
    to_pil = T.ToPILImage()

    for index, axis in enumerate(axes):
        axis.imshow(to_pil(batch_tensor[index]))
        axis.axis('off')

    plt.show()

In [None]:
# Take the first sample from an un-normalized view of the dataset. The tensors of
# `imagenet_v2_dataset` were mean/std normalized for the model at this point, so
# their values fall outside the [0, 1] range that plt.imshow and ToPILImage expect.
test_image, test_label = load_dataset(visualization=True)[0]
show_image(test_image)

In [None]:
# Build a batch from the test image plus 8 augmented copies and display it in one row
batch = create_augmented_batch(test_image, n=8)
show_batch_images(batch)

In [None]:
test_loader = DataLoader(imagenet_v2_dataset, batch_size=1, shuffle=True)

# The loader yields (image, label) pairs, so unpack them instead of passing the
# whole pair on to the augmentation helper.
for image, label in tqdm(test_loader):
    # Drop the size-1 batch dimension: the PIL conversion inside
    # create_augmented_batch only accepts 2-D/3-D tensors.
    batch = create_augmented_batch(image.squeeze(0), n=8)
    # Snapshot of the current model, taken once per test sample
    original_model = deepcopy(model)
