# Custom Dataset Definition for Deep Learning Classification

## Task:

Your task is to define a custom dataset for deep learning-based classification.

## Instructions:

Follow the steps below to define a custom dataset for classification:

1. Create a new Python class inheriting from `torch.utils.data.Dataset`.
2. Implement the `__init__` method to initialize the dataset.
3. Implement the `__len__` method to return the total number of samples in the dataset.
4. Implement the `__getitem__` method to return a sample and its corresponding label.
5. Load and preprocess the data within the `__init__` method or within the `__getitem__` method.
6. Return the sample and label in the `__getitem__` method.

Once you have defined the custom dataset class, you can use it with PyTorch's DataLoader to load and iterate over the data during training or evaluation.
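
The steps above can be sketched with a toy dataset. This is a minimal illustration rather than the STL-10 solution: `ToyDataset`, its random tensors, and its sizes are hypothetical stand-ins chosen so the example runs without any files on disk.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: random 3x8x8 'images' with integer labels."""

    def __init__(self, num_samples=10, num_classes=3):
        super().__init__()
        # Fixed seed so the toy samples are reproducible.
        generator = torch.Generator().manual_seed(0)
        self.samples = torch.rand(num_samples, 3, 8, 8, generator=generator)
        self.labels = torch.arange(num_samples) % num_classes

    def __len__(self):
        # Total number of samples in the dataset.
        return len(self.samples)

    def __getitem__(self, index):
        # One sample and its corresponding label.
        return self.samples[index], self.labels[index]


toy_dataset = ToyDataset()
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)

# The DataLoader stacks individual samples into batches.
first_images, first_labels = next(iter(toy_loader))
print(first_images.shape, first_labels)
```

The `DataLoader` only relies on `__len__` and `__getitem__`, so the same pattern carries over once the samples come from image files instead of random tensors.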

In this practical session, we use another 10-category classification dataset, STL-10. The processed STL-10 dataset can be downloaded from
https://drive.google.com/file/d/18j4RTzC5uCqT6QV96-oa-FNGztfDx5oH/view?usp=drive_link

Download and unzip the data so that the directory structure looks like this:
```bash
.
├── COMP8430_Practical_Week_3.ipynb
└── stl-10
    ├── test_images
    │    ├── test_image_png_1.png
    │    ├── test_image_png_2.png
    │    ├── test_image_png_3.png
    │    └── ......
    ├── test.json
    ├── train_images
    │    ├── train_image_png_1.png
    │    ├── train_image_png_2.png
    │    ├── train_image_png_3.png
    │    └── ......
    └── train.json
```
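
Before writing the dataset class, it can help to inspect the label files. The sketch below assumes each JSON file holds a list of records with `file` and `label` keys, which is how the loading code in Step 1 reads them. The helper name `read_index` and the demo records are hypothetical, and whether `file` stores a bare file name or a longer path should be checked against the real `train.json`.

```python
import json
import tempfile
from collections import Counter
from pathlib import Path


def read_index(json_path):
    """Return (file, label) pairs from a JSON list of {"file": ..., "label": ...} records."""
    with open(json_path, "r") as f:
        records = json.load(f)
    return [(record["file"], record["label"]) for record in records]


# Hypothetical stand-in for ./stl-10/train.json so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_json = Path(tmp_dir) / "train.json"
    demo_json.write_text(json.dumps([
        {"file": "train_image_png_1.png", "label": 0},
        {"file": "train_image_png_2.png", "label": 3},
        {"file": "train_image_png_3.png", "label": 0},
    ]))
    index = read_index(demo_json)

# Count how many samples each label has.
label_counts = Counter(label for _, label in index)
print(len(index), dict(label_counts))
```

Pointing `read_index` at `./stl-10/train.json` instead of the temporary file gives a quick check of the number of samples and the label distribution.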

In [None]:
# Install necessary packages
!pip install torch torchvision Pillow

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Step 1: Data Preprocessing and Augmentation

Define the STL-10 dataset, apply transformations for data preprocessing and augmentation, and create data loaders for training and validation.

In [None]:
from torch.utils.data import Dataset,DataLoader
import json
from PIL import Image
from os.path import join

# Implement a Dataset to load the STL-10 dataset for training and evaluation
class STL10(Dataset):
    def __init__(self,train):
        super().__init__()
        self.train=train

        # Read the file names and labels
        self.label=[]
        self.image_path=[]
        if self.train:
            with open("./stl-10/train.json","r") as f:
                file_label_dict=json.load(f)
        else:
            with open("./stl-10/test.json","r") as f:
                file_label_dict=json.load(f)
        for data_dict in file_label_dict:
            self.image_path.append(data_dict["file"])
            self.label.append(data_dict["label"])
        
        # define the data transforms
        # NOTE add your implementation of train / test transforms
        if train:
            self.transform = None
        else:
            self.transform=None
    def __len__(self):
        # NOTE add your implementation to return the length of the dataset
        return None
    def __getitem__(self, index):
        # NOTE add your implementation to return the item
        return None
    
# NOTE add your implementation to create STL-10 dataset
train_dataset = None
test_dataset = None

# NOTE add your implementation to create data loaders
train_loader = None 
test_loader = None

## Step 2: Define the Lightweight CNN Model

Define a lightweight CNN model for image classification. The model uses batch normalization, two residual (shortcut) connections, global average pooling, and dropout.

In [None]:
import torch.nn.functional as F
# Define the CNN model
class LightweightCNN(nn.Module):
    def __init__(self):
        super(LightweightCNN, self).__init__()
        self.conv_bn1 = nn.Sequential(nn.Conv2d(3, 48, 3, padding=1),nn.BatchNorm2d(48))
        self.conv_bn2_1 = nn.Sequential(nn.Conv2d(48, 96, 3, padding=1),nn.BatchNorm2d(96))
        self.conv_bn2_2 = nn.Sequential(nn.Conv2d(96, 96, 3, padding=1),nn.BatchNorm2d(96))
        self.conv_bn2_res=nn.Sequential(nn.Conv2d(48,96,1,stride=2),nn.BatchNorm2d(96))
        self.conv_bn3_1 = nn.Sequential(nn.Conv2d(96, 192, 3, padding=1),nn.BatchNorm2d(192))
        self.conv_bn3_2 = nn.Sequential(nn.Conv2d(192, 192, 3, padding=1),nn.BatchNorm2d(192))
        self.conv_bn3_res=nn.Sequential(nn.Conv2d(96,192,1,stride=2),nn.BatchNorm2d(192))
        self.pool = nn.MaxPool2d(2, 2)
        self.avg_pool=nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(192, 768)
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(768, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv_bn1(x)))
        x_res=self.conv_bn2_res(x)
        x = self.pool(F.relu(self.conv_bn2_1(x)))
        x=self.conv_bn2_2(x)+x_res
        x=F.relu(x)
        x_res=self.conv_bn3_res(x)
        x = self.pool(F.relu(self.conv_bn3_1(x)))
        x=self.conv_bn3_2(x)+x_res
        x=F.relu(x)
        x = self.avg_pool(x).view(x.shape[0], -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create an instance of the lightweight CNN model
model = LightweightCNN()
print(model)
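
The two shortcut branches can only be added to the main branches when their spatial sizes match. The sketch below traces one spatial axis through `forward` in plain Python, using the standard output-size formula for convolution and pooling. `conv_out` and `trace_sizes` are helper names introduced here for illustration, and the 96-pixel input assumes the original STL-10 resolution.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_sizes(size):
    """Follow one spatial dimension through LightweightCNN.forward."""
    # conv_bn1 (3x3, padding 1) followed by 2x2 max pooling.
    size = conv_out(conv_out(size, 3, padding=1), 2, stride=2)
    sizes = [size]
    for stage in (2, 3):  # the two residual stages
        # Shortcut: 1x1 convolution with stride 2.
        shortcut = conv_out(size, 1, stride=2)
        # Main branch: 3x3 convolution, 2x2 max pooling, 3x3 convolution.
        main = conv_out(conv_out(size, 3, padding=1), 2, stride=2)
        main = conv_out(main, 3, padding=1)
        if main != shortcut:
            raise ValueError(
                f"stage {stage}: main branch is {main} but shortcut is {shortcut}"
            )
        size = main
        sizes.append(size)
    return sizes


print(trace_sizes(96))
```

Input sizes such as 96 or 32 pass this check. A size such as 50 leaves an odd 25 after the first pooling layer, where max pooling rounds down to 12 and the strided 1x1 convolution rounds up to 13, so the residual addition would fail.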

## Step 3: Train the Model

Train the lightweight CNN model using the training dataset with techniques such as learning rate scheduling and model checkpointing.

In [None]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)

# Define learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

# Train the model
model.to(device)
num_epochs = 12
best_val_loss = float('inf')
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    val_loss /= len(test_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}')
    scheduler.step()
    
    # Save the model if validation loss has decreased
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
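
To see what the scheduler does over the 12 epochs, the closed form of a step schedule can be evaluated directly. This is a plain-Python sketch: `step_lr` is a hypothetical helper that mirrors the settings above (`lr=0.001`, `step_size=8`, `gamma=0.1`) rather than calling `StepLR` itself, and it assumes `scheduler.step()` is called once at the end of each epoch, as in the training loop.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate used during the given 0-based epoch under a step schedule."""
    return base_lr * gamma ** (epoch // step_size)


# Same settings as the optimizer and scheduler defined above.
schedule = [step_lr(0.001, epoch, step_size=8, gamma=0.1) for epoch in range(12)]

for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

Under these assumptions, the first eight epochs train at 0.001 and the last four at 0.0001.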

## Step 4: Evaluate the Model

Evaluate the performance of the trained model on the test dataset.

In [None]:
# Load the best model
best_model = LightweightCNN()
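# map_location lets a checkpoint saved on a GPU machine load on a CPU-only machine
best_model.load_state_dict(torch.load('best_model.pth', map_location=device))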
best_model.to(device)

# Evaluate the model
best_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = best_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Print accuracy
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')
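
The accuracy computation can be checked by hand on a tiny batch. The logits and labels below are made-up values for a hypothetical 3-class problem, chosen so that exactly one of the four predictions is wrong.

```python
import torch

# Made-up logits for 4 samples and 3 classes.
outputs = torch.tensor([
    [2.0, 0.5, 0.1],
    [0.2, 1.5, 0.3],
    [0.1, 0.2, 3.0],
    [1.0, 0.9, 0.8],
])
labels = torch.tensor([0, 1, 2, 1])

# torch.max along dim 1 returns (max values, indices); the indices are the predicted classes.
_, predicted = torch.max(outputs, 1)

correct = (predicted == labels).sum().item()
total = labels.size(0)
accuracy = 100 * correct / total
print(predicted.tolist(), f"{accuracy:.2f}%")
```

The last sample is predicted as class 0 but labelled 1, so 3 of 4 predictions are correct.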

## Step 5: Run the Code on the Robot

1. Save the code in this notebook as a Python script (.py file).
2. Compress the Python script and the dataset into a ZIP file.
3. Power on the robot, wait until it has fully booted, and connect your laptop to the Wi-Fi network shared by the robot. The password is *hiwonder*.
4. Use NoMachine to connect to the robot's Linux operating system. Both the username and password are *ubuntu*.
5. Navigate to */home/ubuntu/COMP8430/*. If the *COMP8430* folder does not exist, create it. Then, create a new folder named after your MQ ID inside *COMP8430*.
6. Use NoMachine to upload the ZIP file to your folder, then extract its contents.
7. Open a terminal in your folder and run the Python script using the *python* command.
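
The compression step can also be done from Python. The sketch below is one way to do it with the standard `zipfile` module. `make_bundle` and the file names `practical.py` and `bundle.zip` are hypothetical, and the demo builds a throwaway folder so it runs without the real dataset.

```python
import tempfile
import zipfile
from pathlib import Path


def make_bundle(script_path, data_dir, zip_path):
    """Write the script and every file under data_dir into one ZIP archive."""
    script_path, data_dir = Path(script_path), Path(data_dir)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        archive.write(script_path, arcname=script_path.name)
        for path in sorted(data_dir.rglob("*")):
            if path.is_file():
                # Keep the data folder name so relative paths like ./stl-10/train.json still resolve.
                arcname = (Path(data_dir.name) / path.relative_to(data_dir)).as_posix()
                archive.write(path, arcname=arcname)
    return zip_path


# Throwaway demo layout; replace with the real script and the stl-10 folder.
with tempfile.TemporaryDirectory() as tmp_dir:
    root = Path(tmp_dir)
    (root / "practical.py").write_text("print('hello robot')\n")
    (root / "stl-10").mkdir()
    (root / "stl-10" / "train.json").write_text("[]")

    bundle = make_bundle(root / "practical.py", root / "stl-10", root / "bundle.zip")
    with zipfile.ZipFile(bundle) as archive:
        names = sorted(archive.namelist())

print(names)
```

Keeping the `stl-10` folder name inside the archive means the relative paths used by the dataset class still work after extracting on the robot, provided the script is run from the extraction folder.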

