## HW3 Image Classification
#### Solve image classification with convolutional neural networks (CNN).
#### If you have any questions, please contact the TAs via TA hours, NTU COOL, or email to mlta-2023-spring@googlegroups.com

### Import Packages

In [None]:
# Import necessary packages.
import gc
import numpy as np
import pandas as pd
import torch
import os
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder, VisionDataset
# This is for the progress bar.
from tqdm.auto import tqdm
import random

In [None]:
myseed = 1091102  # set a random seed for reproducibility

# Ask cuDNN for deterministic kernels and turn off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed every RNG this notebook draws from. Python's `random` matters too:
# it produces the RandomErasing fill values and shuffles the train/valid split.
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

### Transforms

In [None]:
# Validation / test pipeline: no augmentation, only a resize to a fixed
# shape followed by conversion of the PIL image to a tensor.
image_size = (256, 256)
test_tfm = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ]
)

# Training pipeline. It can also be reused at test time to generate several
# augmented views of one image for ensembling.
# Note: the random crop means training tensors are 200x200, whereas test_tfm
# yields 256x256.
train_tfm = transforms.Compose(
    [
        # PIL-level augmentations, applied after resizing to `image_size`.
        transforms.Resize(image_size),
        transforms.RandomRotation(45),
        transforms.RandomCrop((200, 200), padding_mode="edge"),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAdjustSharpness(1.5, p=0.5),
        transforms.RandomAutocontrast(p=0.5),
        transforms.RandomPosterize(5, p=0.5),
        transforms.RandomPerspective(distortion_scale=0.15, p=0.5),
        transforms.ToTensor(),
    ]
    + [
        # Four independent small erasures on the tensor. Each fill colour is
        # drawn once, here at construction time, not once per image.
        transforms.RandomErasing(
            p=0.5,
            scale=(0.005, 0.005),
            value=tuple(random.random() for _ in range(3)),
        )
        for _ in range(4)
    ]
)

### Datasets

In [None]:
class FoodDataset(Dataset):
    """Image dataset whose labels are encoded in the file names.

    Each sample is ``(transformed_image, label)``. The label is the integer
    before the first ``_`` in the file name; files without such a prefix
    (e.g. the unlabeled test set) get the label ``-1``.

    Args:
        path: Directory scanned for ``*.jpg`` files when ``files`` is None.
        tfm: Transform applied to every opened PIL image.
        files: Optional explicit list of image paths. When given, it is used
            as-is and ``path`` is not scanned.
    """

    def __init__(self, path, tfm=test_tfm, files=None):
        super().__init__()
        self.path = path
        if files is not None:
            # An explicit file list wins; skip the directory scan entirely.
            self.files = files
        else:
            self.files = sorted(
                os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")
            )
        self.transform = tfm

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the file at ``idx``."""
        fname = self.files[idx]
        im = self.transform(Image.open(fname))

        try:
            label = int(os.path.basename(fname).split("_")[0])
        except ValueError:
            label = -1  # test has no label

        return im, label

In [None]:
class FoodDataset_TTA(Dataset):
    """Dataset for test-time augmentation (TTA).

    Each sample is ``(augmented_views, plain_image, label)``, where
    ``augmented_views`` is a list of ``TTA_num`` independently augmented
    versions of the image and ``plain_image`` is the image passed through
    ``test_tfm``. The label is the integer before the first ``_`` in the
    file name, or ``-1`` when the name has no such prefix.

    Args:
        path: Directory scanned for ``*.jpg`` files when ``files`` is None.
        train_tfm: Augmenting transform used for the TTA views.
        test_tfm: Transform used for the single non-augmented view.
        TTA_num: Number of augmented views to generate per image.
        files: Optional explicit list of image paths. When given, it is used
            as-is and ``path`` is not scanned.
    """

    def __init__(self, path, train_tfm, test_tfm, TTA_num, files=None):
        super().__init__()
        self.path = path
        if files is not None:
            # An explicit file list wins; skip the directory scan entirely.
            self.files = files
        else:
            self.files = sorted(
                os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")
            )

        self.train_transform = train_tfm
        self.test_transform = test_tfm
        # Keep the constructor argument so __getitem__ does not depend on a
        # module-level global of the same name.
        self.TTA_num = TTA_num

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(augmented_views, plain_image, label)`` for ``idx``."""
        fname = self.files[idx]
        im = Image.open(fname)
        train_im = [self.train_transform(im) for _ in range(self.TTA_num)]
        test_im = self.test_transform(im)

        try:
            label = int(os.path.basename(fname).split("_")[0])
        except ValueError:
            label = -1  # test has no label

        return train_im, test_im, label

### Model

In [None]:
class Classifier(nn.Module):
    """ResNeXt-101 (32x8d) wrapper used here as a feature extractor.

    ``forward`` runs only the stem and the first three residual stages of
    the backbone and returns the ``layer3`` activations. The ``fc`` head is
    constructed but never called in ``forward``; presumably it stays so the
    module's parameters match the saved checkpoint (confirm before removing).
    """

    def __init__(self):
        super().__init__()
        self.cnn = models.resnext101_32x8d()
        self.fc = nn.Linear(1000, 11)

    def forward(self, x):
        """Return the backbone's ``layer3`` output for the image batch ``x``."""
        backbone = self.cnn
        stages = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out

### Configurations

In [None]:

# Number of samples per batch for the train and valid DataLoaders.
batch_size = 32
# Number of augmented views per image for test-time augmentation.
TTA_num = 5
# Not referenced in the visible part of this notebook; presumably the weight
# between plain and augmented predictions in TTA — TODO confirm.
TTA_ratio = 0.8
# Per-class fraction of the pooled train+valid images put in the training split.
train_valid_ratio = 0.9


### Dataloader

In [None]:
def _label_from_path(file_path):
    """Return the integer class id given by the ``<label>_`` file-name prefix."""
    return int(os.path.basename(file_path).split("_")[0])


def _stratified_split(file_paths, train_ratio, rng, num_classes=11):
    """Split ``file_paths`` into sorted ``(train, valid)`` lists, class by class.

    Args:
        file_paths: Image paths whose file names start with ``<label>_``.
        train_ratio: Fraction of each class assigned to the training list.
        rng: A ``random.Random`` instance used to shuffle each class.
        num_classes: Number of distinct labels (labels index ``0..num_classes-1``).

    Returns:
        ``(train_paths, valid_paths)``, both sorted.
    """
    per_class = [[] for _ in range(num_classes)]
    # Sort first: os.listdir order is arbitrary, so without this the same
    # seed could still produce a different split on another machine.
    for file_path in sorted(file_paths):
        per_class[_label_from_path(file_path)].append(file_path)

    train_paths, valid_paths = [], []
    for class_files in per_class:
        rng.shuffle(class_files)
        pick_num = int(len(class_files) * train_ratio)
        train_paths.extend(class_files[:pick_num])
        valid_paths.extend(class_files[pick_num:])
    return sorted(train_paths), sorted(valid_paths)


# Pool the provided train and valid folders, then re-split them per class.
data_root = "/kaggle/input/ml2023spring-hw3"
files = [
    os.path.join(data_root, split_name, x)
    for split_name in ("train", "valid")
    for x in os.listdir(os.path.join(data_root, split_name))
    if x.endswith(".jpg")
]

# A dedicated RNG seeded with `myseed` keeps the split reproducible no matter
# how much of the global `random` stream was consumed earlier.
train_file, valid_file = _stratified_split(
    files, train_valid_ratio, random.Random(myseed)
)

del files
gc.collect()

In [None]:
# Build the train and valid datasets from the re-split file lists. Because
# `files` is passed explicitly, those lists decide which images each dataset
# serves.
train_set = FoodDataset(
    "/kaggle/input/ml2023spring-hw3/train",
    tfm=train_tfm,
    files=train_file,
)
valid_set = FoodDataset(
    "/kaggle/input/ml2023spring-hw3/valid",
    tfm=test_tfm,
    files=valid_file,
)

# Both loaders shuffle and load in the main process (num_workers=0).
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

# Report the split sizes: train first, then valid, one per line.
print(len(train_set), len(valid_set), sep="\n")

# t-SNE

In [None]:
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.cm as cm
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Read the trained weights; map_location places the loaded tensors on the CPU.
read_model = '/kaggle/input/models/sample_best.ckpt'
state_dict = torch.load(read_model, map_location=torch.device('cpu'))

# Build the network on the target device, restore the weights and switch to
# inference mode.
model = Classifier().to(device)
model.load_state_dict(state_dict)
model.eval()
print("!")




In [None]:
# print(model.cnn)

# new_model = nn.Sequential(model.cnn.conv1 , 
#                           model.cnn.bn1 ,
#                           model.cnn.relu ,
#                         model.cnn.maxpool,
#                           model.cnn.layer1
#                          )
# new_model.eval()
# print("!")

In [None]:
# del model , train_set , train_loader
# gc.collect()

In [None]:
# Run every validation image through the truncated network and collect one
# flattened feature vector per image, together with its label.
features = []
labels = []
with torch.no_grad():
    for imgs, lbls in tqdm(valid_loader):
        activations = model(imgs.to(device))
        # Flatten to (batch, D). Deliberately no np.squeeze: it would drop the
        # batch axis when a batch holds a single image and corrupt the matrix.
        flat = activations.reshape(activations.size(0), -1)
        features.append(flat.cpu().numpy())
        labels.append(lbls.cpu().numpy())

# Stack the per-batch arrays; fall back to empty arrays if the loader was empty.
features = np.concatenate(features, axis=0) if features else np.empty((0, 0))
labels = np.concatenate(labels, axis=0) if labels else np.empty((0,), dtype=np.int64)

# One distinct colour per class (11 classes) for the t-SNE scatter plot.
colors_per_class = cm.rainbow(np.linspace(0, 1, 11))

print("T-SNE")
del valid_set, valid_loader
gc.collect()


In [None]:
# Show the dimensions of the extracted feature array before running t-SNE on it.
print(features.shape)

In [None]:
# Project the features to 2-D with t-SNE (PCA initialisation, fixed seed).
features_tsne = TSNE(n_components=2, init='pca', random_state=42).fit_transform(features)

# Boolean masking needs an ndarray; `labels` may still be a plain Python list.
label_array = np.asarray(labels)

# Plot one scatter series per class. Colours come from `colors_per_class`,
# because matplotlib's default cycle has only 10 colours and the 11 classes
# would otherwise contain two that look identical.
plt.figure(figsize=(10, 8))
for label in np.unique(label_array):
    mask = label_array == label
    plt.scatter(
        features_tsne[mask, 0],
        features_tsne[mask, 1],
        # A single-row 2-D array is the unambiguous way to give one RGBA colour.
        c=colors_per_class[int(label)].reshape(1, -1),
        label=int(label),
        s=5,
    )
plt.legend()
plt.savefig("t_sne.jpg")
plt.show()