In [1]:
import vgg_class
from data import DatasetManager, TransformerManager
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from sklearn.model_selection import KFold
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from tqdm import tqdm
from PIL import Image
import re
from pathlib import Path
import pandas as pd
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from loguru import logger

In [2]:
data_standard_transforms = {
    "train": transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
    "test": transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
}

In [3]:
#train_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/train_Data/Cashew/"
#test_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/test_data/Cashew/"
root_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/"
full_set = DatasetManager(root_dir, transform=data_standard_transforms)

In [4]:
NUM_EPOCHS = 1  # Number of passes through entire training dataset
CV_FOLDS = 2  # Number of cross-validation folds
BATCH_SIZE = 64  # Within each epoch data is split into batches
LEARNING_RATE = 0.001
VAL_SPLIT = 0.2
CROSS_VALIDATE = True

device_ids = [i for i in range(torch.cuda.device_count())]
vgg = vgg_class.vgg16((len(full_set.unique_crops), len(full_set.unique_states)))
model = nn.DataParallel(vgg, device_ids=device_ids)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion_crop = nn.CrossEntropyLoss()
criterion_state = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [9]:

# Define cross-validation iterator
#skf = StratifiedShuffleSplit(n_splits=CV_FOLDS, test_size=VAL_SPLIT, random_state=42)

# Determine the number of splits
#n_splits = skf.get_n_splits(train_set.samples, train_set.targets)

train_loss_crop = []
train_loss_state = []
train_accuracy_crop = []
train_accuracy_state = []
train_total = 0
epoch_stats = []
# Training Loop
for epoch in range(NUM_EPOCHS):
    model.train()
    running_crop_loss = 0.0
    running_state_loss = 0.0
    crop_correct = 0
    state_correct = 0
    total = 0
    
    # Initialize tqdm for epoch progress
    epoch_progress = tqdm(total=NUM_EPOCHS, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    '''
    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_set.samples, train_set.targets)):
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
    '''        
    train_loader = torch.utils.data.DataLoader(full_set.train_samples, batch_size=BATCH_SIZE)
    #val_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, sampler=val_sampler)

        # Initialize tqdm for fold progress

    for batch_idx, batch in enumerate(train_loader):
        crop_label_idx = batch['crop_idx']
        img_paths = batch['img_path']
        splits = batch['split']
        state_label_idx = batch['state_idx']
        images = []
        for path, split in zip(img_paths, splits):
            images.append(full_set.load_image_from_path(path, split))

        #fold_progress = tqdm(total=len(train_loader), desc=f'Fold {batch_idx + 1}/{len(train_loader)}', leave=False)

        images_tensor = torch.stack(images, dim=0)
        #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
        inputs = images_tensor.clone().detach().requires_grad_(True)
        crop_labels = crop_label_idx.clone().detach()
        state_labels = state_label_idx.clone().detach()

        inputs = inputs.to(device)
        crop_labels = crop_labels.to(device)
        state_labels = state_labels.to(device)
            # Zero the parameter gradients
        optimiser.zero_grad()
            
            # Forward pass
        crop_outputs, state_outputs = model(inputs)
        
        #crop_outputs = model_outputs[:, :len(full_set.unique_crops)]
        #state_outputs = model_outputs[:, len(full_set.unique_states):]

            # Calculate loss
        crop_loss = criterion_crop(crop_outputs, crop_labels)
        state_loss = criterion_state(state_outputs, state_labels)
        
        #running_loss = running_loss + crop_loss + state_loss    
            # Backward pass and optimize
        crop_loss.backward(retain_graph=True)
        state_loss.backward()
        optimiser.step()

        _, predicted_crop = torch.max(crop_outputs, 1)
        _, predicted_state = torch.max(state_outputs, 1)
                    
        crop_correct += (predicted_crop == crop_labels).sum().item()
        state_correct += (predicted_state == state_labels).sum().item()
        total += crop_labels.size(0)

        running_crop_loss+= crop_loss.item()
        running_state_loss += state_loss.item()

            # Update tqdm progress description at batch level
        #fold_progress.set_postfix(loss=running_loss / (batch_idx + 1), accuracy=100. * correct / total)
        #fold_progress.set_postfix({})
        #fold_progress.update(1)

        # Close fold progress bar
        #fold_progress.close()

        # Update tqdm progress at fold level
        epoch_progress.update(1)

    # Close epoch progress bar
    epoch_progress.close()

    # Calculate epoch-level metrics
    epoch_crop_loss = running_crop_loss / len(train_loader)
    epoch_state_loss = running_state_loss / len(train_loader)
    epoch_crop_accuracy = 100. * crop_correct / total
    epoch_state_accuracy = 100. * state_correct / total

    # Log metrics
    train_loss_crop.append(epoch_crop_loss)
    train_loss_state.append(epoch_state_loss)
    train_accuracy_crop.append(epoch_crop_accuracy)
    train_accuracy_state.append(epoch_crop_accuracy)


Epoch 1/5:   0%|          | 0/5 [00:00<?, ?it/s]

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
Epoch 1/5: 2509it [49:52,  1.19s/it]                    


NameError: name 'train_losses' is not defined

In [134]:
print("Correct Crop: ", train_correct_crop/train_total)
print("Correct State: ", train_correct_state/train_total)

Correct Crop:  0.9692628720210288
Correct State:  0.4776407419865207


In [135]:
total = 0
crop_correct = 0
state_correct = 0
model.eval()

         #torch.utils.data.DataLoader(full_set.train_samples, batch_size=BATCH_SIZE)
testing = torch.utils.data.DataLoader(full_set.test_samples)

for batch_idx, batch in enumerate(testing):
    crop_label_idx = batch['crop_idx']
    img_paths = batch['img_path']
    splits = batch['split']
    state_label_idx = batch['state_idx']
    images = []
    for path, split in zip(img_paths, splits):
        images.append(full_set.load_image_from_path(path, split))

    images_tensor = torch.stack(images, dim=0)
    #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
    inputs = images_tensor.clone().detach().requires_grad_(True)
    crop_labels = crop_label_idx.clone().detach()
    state_labels = state_label_idx.clone().detach()

    inputs = inputs.to(device)
    crop_labels = crop_labels.to(device)
    state_labels = state_labels.to(device)
            # Zero the parameter gradients
    optimiser.zero_grad()
            
            # Forward pass
    crop_outputs, state_outputs = model(inputs)

    _, crop_predicted = crop_outputs.max(1)
    _, state_predicted = state_outputs.max(1)
    total += crop_labels.size(0)
    crop_correct += crop_predicted.eq(crop_labels).sum().item()
    state_correct += state_predicted.eq(state_labels).sum().item()

#with torch.no_grad():
#    pred = model(test_set)

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


In [136]:
crop_accuracy = 100. * crop_correct / total
state_accuracy = 100. * state_correct / total
print("State Test Accuracy: ", state_accuracy)
print("Crop Test Accuracy: ", crop_accuracy)

State Test Accuracy:  3.058324326488131
Crop Test Accuracy:  21.79656538969617


In [35]:
count = 0
for test_idx, (inputs, targets) in enumerate(testing):
    print(targets)
    count += 1
    if count > 5:
        break

tensor([0])
tensor([0])
tensor([0])
tensor([0])
tensor([0])
tensor([0])


In [39]:
def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

In [40]:
x, y = find_classes(test_dir)

In [63]:
def online_mean_and_sd(loader):
    """Compute the mean and sd in an online fashion

        Var[x] = E[X^2] - E^2[X]
    """
    cnt = 0
    fst_moment = torch.empty(3)
    snd_moment = torch.empty(3)

    for data in loader:

        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)

        cnt += nb_pixels

    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

In [65]:
normal_set = torch.utils.data.DataLoader(full_set.test_sample, batch_size=1, shuffle=False)

mean = 0.
std = 0.
for images, _ in normal_set:
    batch_samples = images.size(0) # batch size (the last batch can have smaller size!)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(normal_set.dataset)
std /= len(normal_set.dataset)

In [66]:
print("mean: ", mean)
print("Std: ", std)

mean:  tensor([0.4851, 0.5189, 0.3830])
Std:  tensor([0.2000, 0.1880, 0.2216])
