In [10]:
# Imports
%reset -f
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import wandb
import torch.optim as optim
from datasets.CactusDataset import CactusDataset
from models.LeNet5 import LeNet5
from torch.utils.data import ConcatDataset
import os


In [11]:
def get_label_distribution(dataset, descending=False):
    """Count how many samples carry each label.

    Args:
        dataset: iterable of ``(name, image, label)`` triples (e.g. a CactusDataset).
        descending: if True, order the counts from the highest label value to the lowest.

    Returns:
        list[int]: one count per distinct label, ordered by label VALUE (ascending by
        default), so for 0/1 labels the result is ``[count_0, count_1]``.
    """
    # NOTE(review): labels are used as dict keys; if the dataset yields tensor labels,
    # convert them with int() first — tensors hash by identity. Confirm label type.
    label_counts = {}
    for _, _, label in dataset:
        label_counts[label] = label_counts.get(label, 0) + 1
    # Sort by the label, not by the count, so position i always maps to the i-th label.
    ordered = sorted(label_counts.items(), key=lambda item: item[0], reverse=descending)
    return [count for _, count in ordered]


# Load the training set without any transform, just for inspection.
dataset = CactusDataset(root_dir='./data/train/train', labels_path='./data/train.csv')

# Inspect the first sample of the dataset; each sample is an (image name, image, label) triple.
_, train_features, train_labels = dataset[0]
image_np = np.array(train_features)
print("Image shape: " + str(image_np.shape))
print("Image python class: " + str(type(train_features)))
print("Label: " + str(train_labels))

# Class balance of the training set.
# NOTE(review): the pie labels below assume get_label_distribution returns the counts in
# ascending label order, with label 0 = no cactus and label 1 = cactus — confirm if that
# helper's ordering changes.
label_distribution = get_label_distribution(dataset)
print(label_distribution)
fig, ax = plt.subplots()
ax.pie(label_distribution, labels=['no cactus', 'cactus'], autopct='%1.1f%%')
ax.set_title("Training label distribution")
plt.show()

In [12]:
# Convert each image to a tensor so samples can be batched by a DataLoader.
transform_dataset = transforms.Compose([
    transforms.ToTensor()
])

dataset = CactusDataset(
    root_dir="./data/train/train",
    labels_path="./data/train.csv",
    transform=transform_dataset
)
_, image, label = dataset[0]
print("Image python class: " + str(type(image)))
print("Image shape: " + str(image.shape))
print("Label: " + str(label))
# Show the image: ToTensor produces channels-first (C, H, W); matplotlib wants (H, W, C).
plt.imshow(image.permute(1, 2, 0))
plt.show()

In [13]:
# Peek at a single shuffled batch to check what the DataLoader yields.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

first_batch = next(iter(dataloader), None)
if first_batch is not None:
    i = 0
    img_name, images, labels = first_batch
    print("Batch number: " + str(i))
    print("Batch image names: " + str(img_name))
    print("Batch images shape: " + str(images.shape))
    print("Batch labels shape: " + str(labels.shape))

In [14]:
def compute_mean_std(dataset):
    """Compute the mean and standard deviation of all pixel values in ``dataset``.

    The statistics are pooled over every pixel (and channel) of every image, i.e. they
    equal ``torch.cat([img.flatten() for img in images]).mean()`` / ``.std()``.
    Note: averaging per-image standard deviations instead would ignore the spread between
    image means and underestimate the dataset-wide std.

    Args:
        dataset: iterable of ``(name, image, label)`` triples whose ``image`` is a tensor.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: 0-dim float32 tensors ``(mean, std)``; ``std``
        uses Bessel's correction, like ``torch.Tensor.std``.

    Raises:
        ValueError: if the dataset contains no pixels at all.
    """
    pixel_sum = 0.0
    pixel_sq_sum = 0.0
    n_pixels = 0
    for _, images, _ in dataset:
        # Accumulate in float64: the E[x^2] - E[x]^2 form loses precision in float32.
        values = images.double()
        pixel_sum += values.sum().item()
        pixel_sq_sum += values.square().sum().item()
        n_pixels += values.numel()
    if n_pixels == 0:
        raise ValueError("compute_mean_std() needs at least one non-empty image")
    mean = pixel_sum / n_pixels
    # Unbiased variance over all pixels; max(..., 0) guards tiny negatives from rounding.
    var = max(pixel_sq_sum - n_pixels * mean * mean, 0.0) / max(n_pixels - 1, 1)
    return (
        torch.tensor(mean, dtype=torch.float32),
        torch.tensor(var ** 0.5, dtype=torch.float32),
    )

# --------- DATA AUGMENTATION ---------
# Reload the training set with the plain ToTensor transform to compute pixel statistics.
dataset = CactusDataset(
    root_dir="./data/train/train",
    labels_path="./data/train.csv",
    transform=transform_dataset
)

# Images with label 0 (no cactus) — the class that gets oversampled below.
no_cactus_dataset = dataset.filter(0)

print("**** BEFORE ****")
print("Number of no cactus images: " + str(len(no_cactus_dataset)))
print("Number of cactus images: " + str(len(dataset) - len(no_cactus_dataset)))

# Pixel statistics of the oversampled set (no-cactus images counted three times).
# NOTE(review): these are only consumed by the commented-out Normalize transforms below, and
# they are computed before the train/val split, so validation images contribute to them too.
mean, std = compute_mean_std(ConcatDataset([dataset, no_cactus_dataset, no_cactus_dataset]))
print("Mean: " + str(mean))
print("Std: " + str(std))

# merged_dataset_transforms=transforms.Compose([
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.RandomRotation(20),
#     transforms.RandomEqualize(1),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=mean, std=std),
# ])

# merged_dataset_transforms=transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=mean, std=std),
# ])

merged_dataset_transforms = transforms.Compose([
    transforms.ToTensor(),
])

dataset = CactusDataset(
    root_dir="./data/train/train",
    labels_path="./data/train.csv",
    transform=merged_dataset_transforms
)

# NOTE(review): this test split is only referenced by a commented-out line that would add it
# to the training data; do that only if nothing is evaluated on this split afterwards.
added_dataset = CactusDataset(
    root_dir="./data/test/test",
    labels_path="./data/test.csv",
    transform=merged_dataset_transforms
)

print("Added dataset length: " + str(len(added_dataset)))

# Oversample the minority class: the original set plus two extra copies of every
# no-cactus image, to rebalance the classes.
no_cactus_dataset = dataset.filter(0)
dataset_merged = ConcatDataset([dataset, no_cactus_dataset, no_cactus_dataset])

# Per-class counts of the merged set. The loop variable must not be named `set`: that would
# shadow the builtin for every later cell run in this kernel.
dataset_merged_len = len(dataset_merged)
no_cactus_merged_len = 0
for sub_dataset in dataset_merged.datasets:
    no_cactus_merged_len += len(sub_dataset.filter(0))

print("\n**** AFTER ****")
print("Number of no cactus images: " + str(no_cactus_merged_len))
print("Number of cactus images: " + str(dataset_merged_len - no_cactus_merged_len))

In [15]:
# --------- CREATING THE DATA LOADER AND TRAIN/VAL SPLIT ---------
torch.manual_seed(42)
np.random.seed(42)

# Split the original, NON-oversampled dataset first. `dataset_merged` holds three copies of
# every no-cactus image, so splitting it would put identical images in both train and val,
# leaking training data into validation and inflating the validation metrics.
dataset_used = dataset
train_size = int(0.8 * len(dataset_used))
print("Train size: " + str(train_size))
val_size = len(dataset_used) - train_size
print("Val size: " + str(val_size))
train_subset, val_dataset = torch.utils.data.random_split(
    dataset_used, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Oversample the minority class (label 0 = no cactus) inside the training split only, with
# the same 3x duplication used for `dataset_merged`. The validation set keeps the natural
# class balance.
# NOTE: this reads every training sample (image included) once to obtain its label.
no_cactus_train_indices = [i for i in range(len(train_subset)) if train_subset[i][2] == 0]
no_cactus_train = torch.utils.data.Subset(train_subset, no_cactus_train_indices)
train_dataset = ConcatDataset([train_subset, no_cactus_train, no_cactus_train])

# train_dataset = ConcatDataset([train_dataset, added_dataset])

# Retrieve the lengths of the datasets
print("dataset length: " + str(len(dataset_used)))
print("train length:" + str(len(train_dataset)))
print("val length:" + str(len(val_dataset)))


# Create DataLoader for training and validation (validation order does not matter).
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True)

# print a sample
_, image, label = dataset_used[0]
print("Image shape: " + str(image.shape))
print("Label: " + str(label))


 # Phase 2: defining the model

In [16]:
# Select the compute device: Apple-silicon MPS (Metal Performance Shaders) when available,
# otherwise CUDA if present, otherwise the CPU. `device` must be defined on every path
# because the model and the training cell use it.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    x = torch.ones(1, device=device)  # quick sanity check that allocation on MPS works
    print(x)
else:
    print("MPS device not found.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: " + str(device))

model = LeNet5()
model.to(device)
print(model)

In [17]:
# Toggle Weights & Biases logging for the training cell.
log = True

# Run hyperparameters; when logging, they are also attached to the wandb run.
# NOTE(review): "momentum" is recorded here but never passed to train_model — confirm.
config = dict(
    architecture="LeNet5-notransf-notest",
    dataset="Cactus",
    epochs=20,
    learning_rate=0.001,
    batch_size=32,
    momentum=0.9,
)

if log:
    # authenticate with Weights & Biases before starting a run
    wandb.login()

In [18]:
# --------- TRAINING ---------
WEIGHTS_PATH = './weights/lenet5_model.pth'

if log:
    # start a new wandb run to track this script, with the hyperparameters as run config
    wandb.init(
        project="Challenge_1",
        config=dict(config),
    )
    hparams = wandb.config  # hyperparameters as recorded on the wandb run
    logging_kwargs = {"wandb": wandb}
else:
    hparams = config
    logging_kwargs = {}

# pin_memory only benefits host-to-device copies on CUDA.
pin_memory = device.type == "cuda"
train_dataloader = DataLoader(train_dataset, batch_size=hparams["batch_size"], shuffle=True, pin_memory=pin_memory)
val_dataloader = DataLoader(val_dataset, batch_size=hparams["batch_size"], shuffle=False, pin_memory=pin_memory)

train_kwargs = dict(
    epochs=hparams["epochs"],
    lr=hparams["learning_rate"],
    device=device,
    **logging_kwargs,
)
try:
    # NOTE(review): when no saved weights exist the model is trained twice — first with
    # freeze=False, then again with train_model's default `freeze`. Confirm this two-stage
    # schedule is intended, and whether existing weights should be loaded into `model`
    # before the second call.
    if not os.path.exists(WEIGHTS_PATH):
        model.train_model(train_dataloader, val_dataloader, freeze=False, **train_kwargs)
    model.train_model(train_dataloader, val_dataloader, **train_kwargs)
finally:
    # close the wandb run even if training raises
    if log:
        wandb.finish()