In [1]:
import vgg_class
from data import DatasetManager, TransformerManager
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from sklearn.model_selection import StratifiedShuffleSplit
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from PIL import Image
import re
from pathlib import Path
import pandas as pd
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from loguru import logger


In [2]:
data_standard_transforms = {
    'train': transforms.Compose([transforms.Resize((224,224)),
        transforms.ToTensor()
    ]),
    'test': transforms.Compose([transforms.Resize((224,224)),
        transforms.ToTensor()
    ])
}

In [3]:
#train_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/train_Data/Cashew/"
#test_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/test_data/Cashew/"
root_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/"
full_set = DatasetManager(root_dir, transform=data_standard_transforms)

In [5]:
NUM_EPOCHS = 5  # Number of passes through entire training dataset
CV_FOLDS = 2  # Number of cross-validation folds
BATCH_SIZE = 128  # Within each epoch data is split into batches
LEARNING_RATE = 0.001
VAL_SPLIT = 0.1
CROSS_VALIDATE = True

device_ids = [i for i in range(torch.cuda.device_count())]
vgg = vgg_class.vgg16((len(full_set.unique_crops), len(full_set.unique_states)))
model = nn.DataParallel(vgg, device_ids=device_ids)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion_crop = nn.CrossEntropyLoss()
criterion_state = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=LEARNING_RATE)
model = model.to(device)

In [6]:

# Define cross-validation iterator
#skf = StratifiedShuffleSplit(n_splits=CV_FOLDS, test_size=VAL_SPLIT, random_state=42)

# Determine the number of splits
#n_splits = skf.get_n_splits(full_set.samples, full_set.targets)

train_loss_crop = []
train_loss_state = []
val_loss_crop = []
val_loss_state = []
train_accuracy_crop = []
train_accuracy_state = []
val_accuracy_crop = []
val_accuracy_state = []
train_total = 0
epoch_stats = []
# Training Loop
for epoch in range(NUM_EPOCHS):
    model.train()
    running_crop_loss = 0.0
    running_state_loss = 0.0
    val_running_crop_loss = 0.0
    val_running_state_loss = 0.0
    crop_correct = 0
    val_crop_correct = 0
    val_state_correct = 0
    state_correct = 0
    total = 0
    
    # Initialize tqdm for epoch progress
    epoch_progress = tqdm(total=NUM_EPOCHS, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    '''
    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_set.samples, train_set.targets)):
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
            
    
        val_loader = torch.utils.data.DataLoader(full_set.train_samples, batch_size=BATCH_SIZE, sampler=val_sampler, pin_memory=True)
    '''
    train, valid = random_split(full_set.train_samples, [1-VAL_SPLIT, VAL_SPLIT])
    train_loader = DataLoader(train, batch_size=BATCH_SIZE ,pin_memory= True)
    valid_loader = DataLoader(valid, batch_size=BATCH_SIZE ,pin_memory= True)
            # Initialize tqdm for fold progress

    for batch_idx, batch in enumerate(train_loader):
        crop_label_idx = batch['crop_idx']
        img_paths = batch['img_path']
        splits = batch['split']
        state_label_idx = batch['state_idx']
        images = []
        for path, split in zip(img_paths, splits):
            images.append(full_set.load_image_from_path(path, split))

            #fold_progress = tqdm(total=len(train_loader), desc=f'Fold {batch_idx + 1}/{len(train_loader)}', leave=False)

        images_tensor = torch.stack(images, dim=0)
            #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
        inputs = images_tensor.clone().detach().requires_grad_(True)
        crop_labels = crop_label_idx.clone().detach()
        state_labels = state_label_idx.clone().detach()

        inputs = inputs.to(device)
        crop_labels = crop_labels.to(device)
        state_labels = state_labels.to(device)
                # Zero the parameter gradients
        optimiser.zero_grad(set_to_none=True)
                
                # Forward pass
        crop_outputs, state_outputs = model(inputs)
            
            #crop_outputs = model_outputs[:, :len(full_set.unique_crops)]
            #state_outputs = model_outputs[:, len(full_set.unique_states):]

                # Calculate loss
        crop_loss = criterion_crop(crop_outputs, crop_labels)
        state_loss = criterion_state(state_outputs, state_labels)
            
        total_loss = crop_loss + state_loss
            #running_loss = running_loss + crop_loss + state_loss    
                # Backward pass and optimize
            #crop_loss.backward(retain_graph=True)
            #state_loss.backward()
            

        _, predicted_crop = torch.max(crop_outputs, 1)
        _, predicted_state = torch.max(state_outputs, 1)
                        
        crop_correct += (predicted_crop == crop_labels).sum().item()
        state_correct += (predicted_state == state_labels).sum().item()
        total += len(batch)

        running_crop_loss+= crop_loss.item()
        running_state_loss += state_loss.item()

        total_loss.backward()
            
        optimiser.step()
                # Update tqdm progress description at batch level
            #fold_progress.set_postfix(loss=running_loss / (batch_idx + 1), accuracy=100. * correct / total)
            #fold_progress.set_postfix({})
            #fold_progress.update(1)

            # Close fold progress bar
            #fold_progress.close()

            # Update tqdm progress at fold level
            #if batch_idx == 1:
            #    break
    model.eval()
    for batch_idx, batch in enumerate(valid_loader):
        crop_label_idx = batch['crop_idx']
        img_paths = batch['img_path']
        splits = batch['split']
        state_label_idx = batch['state_idx']
        images = []
        for path, split in zip(img_paths, splits):
            images.append(full_set.load_image_from_path(path, split))

        images_tensor = torch.stack(images, dim=0)
            #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
        inputs = images_tensor.clone().detach().requires_grad_(True)
        crop_labels = crop_label_idx.clone().detach()
        state_labels = state_label_idx.clone().detach()

        inputs = inputs.to(device)
        crop_labels = crop_labels.to(device)
        state_labels = state_labels.to(device)

        crop_outputs, state_outputs = model(inputs)

        val_crop_loss = criterion_crop(crop_outputs, crop_labels)
        val_state_loss = criterion_state(state_outputs, state_labels)

        _, val_predicted_crop = torch.max(crop_outputs, 1)
        _, val_predicted_state = torch.max(state_outputs, 1)

        val_crop_correct += (val_predicted_crop == crop_labels).sum().item()
        val_state_correct += (val_predicted_state == state_labels).sum().item()
        total += len(batch)

        val_running_crop_loss+= crop_loss.item()
        val_running_state_loss += state_loss.item()

        epoch_progress.update(1)

    # Close epoch progress bar
    epoch_progress.close()

    torch.cuda.empty_cache()

    # Calculate epoch-level metrics
    epoch_crop_loss = running_crop_loss / len(train_loader)
    epoch_state_loss = running_state_loss / len(train_loader)
    epoch_crop_accuracy = 100. * crop_correct / len(train)
    epoch_state_accuracy = 100. * state_correct / len(train)
    epoch_val_crop_loss = val_running_crop_loss / len(valid_loader)
    epoch_val_state_loss = val_running_state_loss / len(valid_loader)
    epoch_val_crop_accuracy = 100. * val_crop_correct / len(valid)
    epoch_val_state_accuracy = 100. * val_state_correct / len(valid)

    # Log metrics
    train_loss_crop.append(epoch_crop_loss)
    train_loss_state.append(epoch_state_loss)
    train_accuracy_crop.append(epoch_crop_accuracy)
    train_accuracy_state.append(epoch_state_accuracy)
    val_loss_crop.append(epoch_val_crop_loss)
    val_loss_state.append(epoch_val_state_loss)
    val_accuracy_crop.append(epoch_val_crop_accuracy)
    val_accuracy_state.append(epoch_val_state_accuracy)
    print("Epoch: ", epoch + 1)
    print("Train - Crop Loss: {}, State Loss: {}, Crop Accuracy: {}, State Accuracy: {}".format(train_loss_crop, train_loss_state, train_accuracy_crop, train_accuracy_state))
    print("Validate - Crop Loss: {}, State Loss: {}, Crop Accuracy: {}, State Accuracy: {}".format(val_loss_crop, val_loss_state, val_accuracy_crop, val_accuracy_state))


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
Epoch 1/5: 63it [11:59, 11.42s/it]                       
Epoch 2/5: 63it [11:46, 11.22s/it]                       
Epoch 3/5: 63it [11:34, 11.02s/it]                       
Epoch 4/5: 63it [11:11, 10.66s/it]                       
Epoch 5/5: 63it [11:36, 11.06s/it]                       


In [7]:
torch.cuda.empty_cache()

In [12]:
print("Crop correct: ", crop_correct)
print("State correct: ", state_correct)
print("Total: ", len(full_set.train_samples))
print("crop_accuracy: ", 100. * crop_correct / len(full_set.train_samples))

Crop correct:  78774
State correct:  27017
Total:  80271
crop_accuracy:  98.13506745898269


In [13]:
y = 2
print("y+6")

y+6


In [8]:
for i in range(NUM_EPOCHS):
    print("Epoch: ", i+1)
    print("Crop Loss: ", train_loss_crop[i])
    print("State Loss: ", train_loss_state[i])
    print("Crop Accuracy: ", train_accuracy_crop[i])
    print("State Accuracy: ", train_accuracy_state[i])
    print("Validation Crop Loss: ", val_loss_crop[i])
    print("Validation States Loss: ", val_loss_state[i])
    print("Validation Crop Accuracy: ", val_accuracy_crop[i])
    print("Validation States Accuracy: ", val_accuracy_state[i])


Epoch:  1
Crop Loss:  1.2459352704276025
State Loss:  2.749915881283515
Crop Accuracy:  50.53983721831571
State Accuracy:  20.99136260450695
Validation Crop Loss:  1.061446189880371
Validation States Loss:  2.699982166290283
Validation Crop Accuracy:  58.37797433661393
Validation States Accuracy:  28.80279058178647
Epoch:  2
Crop Loss:  0.9305367541524161
State Loss:  2.2799721538493065
Crop Accuracy:  60.845191296163
State Accuracy:  28.95465367366148
Validation Crop Loss:  0.9835046529769897
Validation States Loss:  2.45900559425354
Validation Crop Accuracy:  64.74398903700012
Validation States Accuracy:  28.877538308209793
Epoch:  3
Crop Loss:  0.8069555794243264
State Loss:  2.0620343174554607
Crop Accuracy:  67.39244781573557
State Accuracy:  34.64509163390731
Validation Crop Loss:  0.7115528583526611
Validation States Loss:  1.7276532649993896
Validation Crop Accuracy:  57.61803911797683
Validation States Accuracy:  31.207175781736638
Epoch:  4
Crop Loss:  0.675626997610109
State

In [9]:
total = 0
crop_correct = 0
state_correct = 0
model.eval()

         #torch.utils.data.DataLoader(full_set.train_samples, batch_size=BATCH_SIZE)
testing = torch.utils.data.DataLoader(full_set.test_samples)

for batch_idx, batch in enumerate(testing):
    crop_label_idx = batch['crop_idx']
    img_paths = batch['img_path']
    splits = batch['split']
    state_label_idx = batch['state_idx']
    images = []
    for path, split in zip(img_paths, splits):
        images.append(full_set.load_image_from_path(path, split))

    images_tensor = torch.stack(images, dim=0)
    #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
    inputs = images_tensor.clone().detach().requires_grad_(True)
    crop_labels = crop_label_idx.clone().detach()
    state_labels = state_label_idx.clone().detach()

    inputs = inputs.to(device)
    crop_labels = crop_labels.to(device)
    state_labels = state_labels.to(device)
            # Zero the parameter gradients
    optimiser.zero_grad()
            
            # Forward pass
    crop_outputs, state_outputs = model(inputs)

    _, crop_predicted = crop_outputs.max(1)
    _, state_predicted = state_outputs.max(1)
    total += crop_labels.size(0)
    crop_correct += crop_predicted.eq(crop_labels).sum().item()
    state_correct += state_predicted.eq(state_labels).sum().item()

#with torch.no_grad():
#    pred = model(test_set)

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


In [10]:
crop_accuracy = 100. * crop_correct / len(full_set.test_samples)
state_accuracy = 100. * state_correct / len(full_set.test_samples)
print("State Test Accuracy: ", state_accuracy)
print("Crop Test Accuracy: ", crop_accuracy)

State Test Accuracy:  41.22733277290741
Crop Test Accuracy:  71.81858212241303


In [27]:
crop_label_test = []
state_label_test = []
testing = torch.utils.data.DataLoader(full_set.test_samples)
count = 0
for batch_idx, batch in enumerate(testing):
        
        crop_label_idx = batch['crop_idx']
        if crop_label_idx not in crop_label_test:
                crop_label_test.append(crop_label_idx)
        img_paths = batch['img_path']
        
        splits = batch['split']
        state_label_idx = batch['state_idx']
        if state_label_idx not in state_label_test:
                state_label_test.append(state_label_idx)
        
        if batch_idx == 8000:
                print(img_paths)
                print(crop_label_idx)
                print(crop_label_idx.clone().detach())
                print(full_set.unique_crops)
                print(state_label_idx)
                print(full_set.unique_states)
                
        

['/scratch/braines/Dataset/CCMT-Dataset-Augmented/test_data/Cassava/bacterial blight/1991cassava_valid_bb.JPG']
tensor([1])
tensor([1])
['maize', 'cassava', 'cashew', 'tomato']
tensor([10])
['leaf miner', 'leaf curl', 'grasshoper', 'leaf spot', 'gmite', 'leaf blight', 'red rust', 'mosaic', 'verticulium wilt', 'streak virus', 'bb', 'farmyw', 'leaf beetle', 'gumosis', 'cashew healthy', 'cassava healthy', 'maize healthy', 'septoria leaf spot', 'tomato healthy', 'anthracnose', 'bspot']


In [17]:
len(crop_label_test)

4

In [14]:

crop_label_idx

tensor([3])

In [39]:
def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

In [40]:
x, y = find_classes(test_dir)

In [63]:
def online_mean_and_sd(loader):
    """Compute the mean and sd in an online fashion

        Var[x] = E[X^2] - E^2[X]
    """
    cnt = 0
    fst_moment = torch.empty(3)
    snd_moment = torch.empty(3)

    for data in loader:

        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)

        cnt += nb_pixels

    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

In [65]:
normal_set = torch.utils.data.DataLoader(full_set.test_sample, batch_size=1, shuffle=False)

mean = 0.
std = 0.
for images, _ in normal_set:
    batch_samples = images.size(0) # batch size (the last batch can have smaller size!)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(normal_set.dataset)
std /= len(normal_set.dataset)

In [66]:
print("mean: ", mean)
print("Std: ", std)

mean:  tensor([0.4851, 0.5189, 0.3830])
Std:  tensor([0.2000, 0.1880, 0.2216])
