In [1]:
import vgg_class
from data import DatasetManager, TransformerManager
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from sklearn.model_selection import StratifiedShuffleSplit
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from PIL import Image
import re
from pathlib import Path
import pandas as pd
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from loguru import logger


In [2]:
data_standard_transforms = {
    'train': transforms.Compose([transforms.Resize((224,224)),
        transforms.ToTensor()
    ]),
    'test': transforms.Compose([transforms.Resize((224,224)),
        transforms.ToTensor()
    ])
}

In [3]:
#train_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/train_Data/Cashew/"
#test_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/test_data/Cashew/"
root_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/"
full_set = DatasetManager(root_dir, transform=data_standard_transforms)

In [4]:
NUM_EPOCHS = 30  # Number of passes through entire training dataset
CV_FOLDS = 2  # Number of cross-validation folds
BATCH_SIZE = 128  # Within each epoch data is split into batches
LEARNING_RATE = 0.001
VAL_SPLIT = 0.1
CROSS_VALIDATE = True

device_ids = [i for i in range(torch.cuda.device_count())]
vgg = vgg_class.vgg16((len(full_set.unique_crops), len(full_set.unique_states)))
model = nn.DataParallel(vgg, device_ids=device_ids)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion_crop = nn.CrossEntropyLoss()
criterion_state = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=LEARNING_RATE)
model = model.to(device)

In [6]:

# Define cross-validation iterator
#skf = StratifiedShuffleSplit(n_splits=CV_FOLDS, test_size=VAL_SPLIT, random_state=42)

# Determine the number of splits
#n_splits = skf.get_n_splits(full_set.samples, full_set.targets)

train_loss_crop = []
train_loss_state = []
val_loss_crop = []
val_loss_state = []
train_accuracy_crop = []
train_accuracy_state = []
val_accuracy_crop = []
val_accuracy_state = []
train_total = 0
epoch_stats = []
# Training Loop
for epoch in range(NUM_EPOCHS):
    model.train()
    running_crop_loss = 0.0
    running_state_loss = 0.0
    val_running_crop_loss = 0.0
    val_running_state_loss = 0.0
    crop_correct = 0
    val_crop_correct = 0
    val_state_correct = 0
    state_correct = 0
    total = 0
    
    # Initialize tqdm for epoch progress
    epoch_progress = tqdm(total=NUM_EPOCHS, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    '''
    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_set.samples, train_set.targets)):
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
            
    
        val_loader = torch.utils.data.DataLoader(full_set.train_samples, batch_size=BATCH_SIZE, sampler=val_sampler, pin_memory=True)
    '''
    train, valid = random_split(full_set.train_samples, [1-VAL_SPLIT, VAL_SPLIT])
    train_loader = DataLoader(train, batch_size=BATCH_SIZE ,pin_memory= True)
    valid_loader = DataLoader(valid, batch_size=BATCH_SIZE ,pin_memory= True)
            # Initialize tqdm for fold progress

    for batch_idx, batch in enumerate(train_loader):
        crop_label_idx = batch['crop_idx']
        img_paths = batch['img_path']
        splits = batch['split']
        state_label_idx = batch['state_idx']
        images = []
        for path, split in zip(img_paths, splits):
            images.append(full_set.load_image_from_path(path, split))

            #fold_progress = tqdm(total=len(train_loader), desc=f'Fold {batch_idx + 1}/{len(train_loader)}', leave=False)

        images_tensor = torch.stack(images, dim=0)
            #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
        inputs = images_tensor.clone().detach().requires_grad_(True)
        crop_labels = crop_label_idx.clone().detach()
        state_labels = state_label_idx.clone().detach()

        inputs = inputs.to(device)
        crop_labels = crop_labels.to(device)
        state_labels = state_labels.to(device)
                # Zero the parameter gradients
        optimiser.zero_grad(set_to_none=True)
                
                # Forward pass
        crop_outputs, state_outputs = model(inputs)
            
            #crop_outputs = model_outputs[:, :len(full_set.unique_crops)]
            #state_outputs = model_outputs[:, len(full_set.unique_states):]

                # Calculate loss
        crop_loss = criterion_crop(crop_outputs, crop_labels)
        state_loss = criterion_state(state_outputs, state_labels)
            
        total_loss = crop_loss + state_loss
            #running_loss = running_loss + crop_loss + state_loss    
                # Backward pass and optimize
            #crop_loss.backward(retain_graph=True)
            #state_loss.backward()
            

        _, predicted_crop = torch.max(crop_outputs, 1)
        _, predicted_state = torch.max(state_outputs, 1)
                        
        crop_correct += (predicted_crop == crop_labels).sum().item()
        state_correct += (predicted_state == state_labels).sum().item()
        total += len(batch)

        running_crop_loss+= crop_loss.item()
        running_state_loss += state_loss.item()

        crop_loss.backward(retain_graph=True)
        state_loss.backward()
        #total_loss.backward()
            
        optimiser.step()
                # Update tqdm progress description at batch level
            #fold_progress.set_postfix(loss=running_loss / (batch_idx + 1), accuracy=100. * correct / total)
            #fold_progress.set_postfix({})
            #fold_progress.update(1)

            # Close fold progress bar
            #fold_progress.close()

            # Update tqdm progress at fold level
            #if batch_idx == 1:
            #    break
    model.eval()
    for batch_idx, batch in enumerate(valid_loader):
        crop_label_idx = batch['crop_idx']
        img_paths = batch['img_path']
        splits = batch['split']
        state_label_idx = batch['state_idx']
        images = []
        for path, split in zip(img_paths, splits):
            images.append(full_set.load_image_from_path(path, split))

        images_tensor = torch.stack(images, dim=0)
            #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
        inputs = images_tensor.clone().detach().requires_grad_(True)
        crop_labels = crop_label_idx.clone().detach()
        state_labels = state_label_idx.clone().detach()

        inputs = inputs.to(device)
        crop_labels = crop_labels.to(device)
        state_labels = state_labels.to(device)

        crop_outputs, state_outputs = model(inputs)

        val_crop_loss = criterion_crop(crop_outputs, crop_labels)
        val_state_loss = criterion_state(state_outputs, state_labels)

        _, val_predicted_crop = torch.max(crop_outputs, 1)
        _, val_predicted_state = torch.max(state_outputs, 1)

        val_crop_correct += (val_predicted_crop == crop_labels).sum().item()
        val_state_correct += (val_predicted_state == state_labels).sum().item()
        total += len(batch)

        val_running_crop_loss+= crop_loss.item()
        val_running_state_loss += state_loss.item()

        epoch_progress.update(1)

    # Close epoch progress bar
    epoch_progress.close()

    torch.cuda.empty_cache()

    # Calculate epoch-level metrics
    epoch_crop_loss = running_crop_loss / len(train_loader)
    epoch_state_loss = running_state_loss / len(train_loader)
    epoch_crop_accuracy = 100. * crop_correct / len(train)
    epoch_state_accuracy = 100. * state_correct / len(train)
    epoch_val_crop_loss = val_running_crop_loss / len(valid_loader)
    epoch_val_state_loss = val_running_state_loss / len(valid_loader)
    epoch_val_crop_accuracy = 100. * val_crop_correct / len(valid)
    epoch_val_state_accuracy = 100. * val_state_correct / len(valid)

    # Log metrics
    train_loss_crop.append(epoch_crop_loss)
    train_loss_state.append(epoch_state_loss)
    train_accuracy_crop.append(epoch_crop_accuracy)
    train_accuracy_state.append(epoch_state_accuracy)
    val_loss_crop.append(epoch_val_crop_loss)
    val_loss_state.append(epoch_val_state_loss)
    val_accuracy_crop.append(epoch_val_crop_accuracy)
    val_accuracy_state.append(epoch_val_state_accuracy)
    print("Epoch: ", epoch + 1)
    print("Train - Crop Loss: {}, State Loss: {}, Crop Accuracy: {}, State Accuracy: {}".format(train_loss_crop[epoch], train_loss_state[epoch], train_accuracy_crop[epoch], train_accuracy_state[epoch]))
    print("Validate - Crop Loss: {}, State Loss: {}, Crop Accuracy: {}, State Accuracy: {}".format(val_loss_crop[epoch], val_loss_state[epoch], val_accuracy_crop[epoch], val_accuracy_state[epoch]))


Epoch 3/30:   0%|          | 0/30 [02:12<?, ?it/s]


Epoch 1/30: 63it [14:51, 14.16s/it]


Epoch:  1
Train - Crop Loss: 0.927922600033009, State Loss: 2.286305971061234, Crop Accuracy: 60.41055312551907, State Accuracy: 28.546315264935497
Validate - Crop Loss: 1.083845615386963, State Loss: 2.321791172027588, Crop Accuracy: 63.23657655412981, State Accuracy: 31.593372368257132


Epoch 2/30: 63it [14:24, 13.72s/it]                          


Epoch:  2
Train - Crop Loss: 0.8170614701456729, State Loss: 2.126362614294069, Crop Accuracy: 67.02840374287138, State Accuracy: 32.45944299872654
Validate - Crop Loss: 0.7233901619911194, State Loss: 1.9474072456359863, Crop Accuracy: 69.81437647938209, State Accuracy: 36.726049582658526


Epoch 3/30: 63it [14:24, 13.72s/it]                          


Epoch:  3
Train - Crop Loss: 0.7134476766122126, State Loss: 1.9593547185965343, Crop Accuracy: 72.03504789325065, State Accuracy: 36.66186811361497
Validate - Crop Loss: 0.7408238053321838, State Loss: 1.8223164081573486, Crop Accuracy: 74.47365142643578, State Accuracy: 41.79643702504049


Epoch 4/30: 63it [14:24, 13.72s/it]                          


Epoch:  4
Train - Crop Loss: 0.5959273016558284, State Loss: 1.7417457894941346, Crop Accuracy: 77.64935496373401, State Accuracy: 43.01533691379215
Validate - Crop Loss: 0.729225754737854, State Loss: 1.763702392578125, Crop Accuracy: 73.10327644200822, State Accuracy: 44.823719945185


Epoch 5/30: 63it [14:26, 13.76s/it]                          


Epoch:  5
Train - Crop Loss: 0.4578528957820572, State Loss: 1.478653730122389, Crop Accuracy: 83.41315541775096, State Accuracy: 50.71286196777587
Validate - Crop Loss: 0.5135307312011719, State Loss: 1.4237726926803589, Crop Accuracy: 86.47066151737884, State Accuracy: 56.04833686308708


Epoch 6/30: 63it [14:22, 13.69s/it]                          


Epoch:  6
Train - Crop Loss: 0.3483661773985466, State Loss: 1.27450183872628, Crop Accuracy: 87.72770057028957, State Accuracy: 56.88915342450584
Validate - Crop Loss: 0.2884375751018524, State Loss: 1.277924656867981, Crop Accuracy: 88.33935467796188, State Accuracy: 61.74162202566339


Epoch 7/30: 63it [14:24, 13.71s/it]                          


Epoch:  7
Train - Crop Loss: 0.26668982690414494, State Loss: 1.1274892781688048, Crop Accuracy: 90.74248380488345, State Accuracy: 61.102652123359725
Validate - Crop Loss: 0.3660416305065155, State Loss: 0.9272240996360779, Crop Accuracy: 91.15485237324032, State Accuracy: 63.69752086707363


Epoch 8/30: 63it [14:21, 13.67s/it]                          


Epoch:  8
Train - Crop Loss: 0.21340127605780035, State Loss: 1.0092492262874029, Crop Accuracy: 92.6789768008416, State Accuracy: 64.61297824040751
Validate - Crop Loss: 0.16356699168682098, State Loss: 0.9788539409637451, Crop Accuracy: 92.5875171296873, State Accuracy: 66.96150492089198


Epoch 9/30: 63it [14:24, 13.73s/it]                          


Epoch:  9
Train - Crop Loss: 0.17708331791568646, State Loss: 0.9183808947031477, Crop Accuracy: 94.05348541055312, State Accuracy: 67.49072587342893
Validate - Crop Loss: 0.10138192772865295, State Loss: 0.7523965239524841, Crop Accuracy: 89.37336489348449, State Accuracy: 66.0894481126199


Epoch 10/30: 63it [14:21, 13.67s/it]                          


Epoch:  10
Train - Crop Loss: 0.1485565971220489, State Loss: 0.8407211533689921, Crop Accuracy: 94.86877803000941, State Accuracy: 70.270195448757
Validate - Crop Loss: 0.23962868750095367, State Loss: 0.9683148860931396, Crop Accuracy: 95.14139778248412, State Accuracy: 72.76691167310328


Epoch 11/30: 63it [14:20, 13.66s/it]                          


Epoch:  11
Train - Crop Loss: 0.1311324439810968, State Loss: 0.7810638992132338, Crop Accuracy: 95.62593433364708, State Accuracy: 72.56934831958364
Validate - Crop Loss: 0.28549662232398987, State Loss: 1.1863150596618652, Crop Accuracy: 95.42793073377351, State Accuracy: 73.61405257256759


Epoch 12/30: 63it [14:22, 13.69s/it]                          


Epoch:  12
Train - Crop Loss: 0.10576773248283208, State Loss: 0.7207192855598652, Crop Accuracy: 96.47306350700404, State Accuracy: 74.17917058856099
Validate - Crop Loss: 0.18655888736248016, State Loss: 1.0162231922149658, Crop Accuracy: 94.35654665503924, State Accuracy: 73.90058552385699


Epoch 13/30: 63it [14:18, 13.63s/it]                          


Epoch:  13
Train - Crop Loss: 0.0958072442602597, State Loss: 0.669869138076242, Crop Accuracy: 96.81911300592436, State Accuracy: 76.2665411660484
Validate - Crop Loss: 0.06745331734418869, State Loss: 0.48356327414512634, Crop Accuracy: 92.3134421328018, State Accuracy: 75.19621278186122


Epoch 14/30: 63it [14:16, 13.59s/it]                          


Epoch:  14
Train - Crop Loss: 0.08224779903987604, State Loss: 0.6119066838142091, Crop Accuracy: 97.26344056253807, State Accuracy: 78.11998228226565
Validate - Crop Loss: 0.16302570700645447, State Loss: 0.8290718793869019, Crop Accuracy: 97.65790457206927, State Accuracy: 79.74336613927993


Epoch 15/30: 63it [14:20, 13.66s/it]                          


Epoch:  15
Train - Crop Loss: 0.06957509984155144, State Loss: 0.5716160736252777, Crop Accuracy: 97.69946293117768, State Accuracy: 79.6273738995626
Validate - Crop Loss: 0.16126424074172974, State Loss: 0.481723815202713, Crop Accuracy: 97.12221253270214, State Accuracy: 80.26660022424318


Epoch 16/30: 63it [14:29, 13.80s/it]                          


Epoch:  16
Train - Crop Loss: 0.0640770833733269, State Loss: 0.5431169963515965, Crop Accuracy: 97.8974032445601, State Accuracy: 80.62399645645313
Validate - Crop Loss: 0.013813738711178303, State Loss: 0.7008064985275269, Crop Accuracy: 96.22523981562227, State Accuracy: 81.27569453095802


Epoch 17/30: 63it [14:24, 13.72s/it]                          


Epoch:  17
Train - Crop Loss: 0.05280443345798196, State Loss: 0.49150826261106845, Crop Accuracy: 98.23930014949339, State Accuracy: 82.37915951497702
Validate - Crop Loss: 0.02114492654800415, State Loss: 0.47163376212120056, Crop Accuracy: 98.3929238818986, State Accuracy: 85.4740251650679


Epoch 18/30: 63it [14:23, 13.70s/it]                          


Epoch:  18
Train - Crop Loss: 0.04637019256690303, State Loss: 0.455460947541009, Crop Accuracy: 98.53275012457782, State Accuracy: 83.58894856320248
Validate - Crop Loss: 0.04937303811311722, State Loss: 0.4832668900489807, Crop Accuracy: 98.1063909306092, State Accuracy: 85.6484365267223


Epoch 19/30: 63it [14:21, 13.68s/it]                          


Epoch:  19
Train - Crop Loss: 0.041342056147441006, State Loss: 0.4225125308585378, Crop Accuracy: 98.65732794418913, State Accuracy: 84.70322794972593
Validate - Crop Loss: 0.01961478590965271, State Loss: 0.3747110664844513, Crop Accuracy: 94.95452846642581, State Accuracy: 83.50566836925377


Epoch 20/30: 63it [14:22, 13.69s/it]                          


Epoch:  20
Train - Crop Loss: 0.05992392793800517, State Loss: 0.4530438459552495, Crop Accuracy: 98.04689662809368, State Accuracy: 83.92669287414871
Validate - Crop Loss: 0.07284712791442871, State Loss: 0.22173482179641724, Crop Accuracy: 99.37710227980565, State Accuracy: 90.10838420331382


Epoch 21/30: 63it [14:25, 13.74s/it]                          


Epoch:  21
Train - Crop Loss: 0.035656505408985295, State Loss: 0.36272027801095913, Crop Accuracy: 98.85942085155861, State Accuracy: 86.89441337688943
Validate - Crop Loss: 0.03556832671165466, State Loss: 0.3061521053314209, Crop Accuracy: 99.40201818861343, State Accuracy: 89.85922511523607


Epoch 22/30: 63it [14:20, 13.67s/it]                          


Epoch:  22
Train - Crop Loss: 0.028727186435062967, State Loss: 0.33033929720389105, Crop Accuracy: 99.10857649078125, State Accuracy: 88.2218592547478
Validate - Crop Loss: 0.14149852097034454, State Loss: 0.38790586590766907, Crop Accuracy: 99.56397159586396, State Accuracy: 91.19222623645197


Epoch 23/30: 63it [14:24, 13.73s/it]                          


Epoch:  23
Train - Crop Loss: 0.057048513115367204, State Loss: 0.39336901133039354, Crop Accuracy: 98.171474447705, State Accuracy: 86.35319196057804
Validate - Crop Loss: 0.007656214293092489, State Loss: 0.3346252739429474, Crop Accuracy: 99.19023296374735, State Accuracy: 92.25115236078236


Epoch 24/30: 63it [14:22, 13.69s/it]                          


Epoch:  24
Train - Crop Loss: 0.02557766399443133, State Loss: 0.2888088541616381, Crop Accuracy: 99.18193898455235, State Accuracy: 89.65311998228226
Validate - Crop Loss: 0.008745172992348671, State Loss: 0.5424212217330933, Crop Accuracy: 99.26498069017067, State Accuracy: 93.37236825713218


Epoch 25/30: 63it [14:26, 13.75s/it]                          


Epoch:  25
Train - Crop Loss: 0.0240809016005678, State Loss: 0.26321737813738594, Crop Accuracy: 99.24007530037096, State Accuracy: 90.69957366701733
Validate - Crop Loss: 0.05922875180840492, State Loss: 0.5781629085540771, Crop Accuracy: 99.5141397782484, State Accuracy: 93.83331257007599


Epoch 26/30: 63it [14:25, 13.74s/it]                          


Epoch:  26
Train - Crop Loss: 0.02941962128275371, State Loss: 0.2665592099317407, Crop Accuracy: 99.08227672886329, State Accuracy: 90.78816233874093
Validate - Crop Loss: 0.10333777964115143, State Loss: 0.46111059188842773, Crop Accuracy: 99.2400647813629, State Accuracy: 94.3067148374237


Epoch 27/30: 63it [14:24, 13.72s/it]                          


Epoch:  27
Train - Crop Loss: 0.02238238673653058, State Loss: 0.22811759199980086, Crop Accuracy: 99.27606444825868, State Accuracy: 91.99103039698798
Validate - Crop Loss: 0.02159169688820839, State Loss: 0.41418853402137756, Crop Accuracy: 99.73838295751838, State Accuracy: 95.41547277936962


Epoch 28/30: 63it [14:18, 13.62s/it]                          


Epoch:  28
Train - Crop Loss: 0.020527247825585432, State Loss: 0.20713606360739312, Crop Accuracy: 99.33143236808593, State Accuracy: 92.84369636232766
Validate - Crop Loss: 0.007107607088983059, State Loss: 0.10762706398963928, Crop Accuracy: 99.66363523109506, State Accuracy: 95.09156596486856


Epoch 29/30: 63it [14:26, 13.75s/it]                          


Epoch:  29
Train - Crop Loss: 0.018889584143484316, State Loss: 0.20153098953640566, Crop Accuracy: 99.40756325784841, State Accuracy: 92.96965837993467
Validate - Crop Loss: 0.00191740901209414, State Loss: 0.048616524785757065, Crop Accuracy: 98.3929238818986, State Accuracy: 93.58415348199826


Epoch 30/30: 63it [14:18, 13.63s/it]                          

Epoch:  30
Train - Crop Loss: 0.019456920039297113, State Loss: 0.19268680536641483, Crop Accuracy: 99.38403189192182, State Accuracy: 93.47489064835834
Validate - Crop Loss: 0.007226305082440376, State Loss: 0.05732180178165436, Crop Accuracy: 99.88787841036502, State Accuracy: 96.98517503425937





In [7]:
torch.cuda.empty_cache()

In [12]:
print("Crop correct: ", crop_correct)
print("State correct: ", state_correct)
print("Total: ", len(full_set.train_samples))
print("crop_accuracy: ", 100. * crop_correct / len(full_set.train_samples))

Crop correct:  78774
State correct:  27017
Total:  80271
crop_accuracy:  98.13506745898269


In [13]:
y = 2
print("y+6")

y+6


In [14]:
train_loader.

{'cashew': 18910, 'cassava': 20212, 'maize': 19426, 'tomato': 21723}

In [8]:
for i in range(NUM_EPOCHS):
    print("Epoch: ", i+1)
    print("Crop Loss: ", train_loss_crop[i])
    print("State Loss: ", train_loss_state[i])
    print("Crop Accuracy: ", train_accuracy_crop[i])
    print("State Accuracy: ", train_accuracy_state[i])
    print("Validation Crop Loss: ", val_loss_crop[i])
    print("Validation States Loss: ", val_loss_state[i])
    print("Validation Crop Accuracy: ", val_accuracy_crop[i])
    print("Validation States Accuracy: ", val_accuracy_state[i])


Epoch:  1
Crop Loss:  1.2459352704276025
State Loss:  2.749915881283515
Crop Accuracy:  50.53983721831571
State Accuracy:  20.99136260450695
Validation Crop Loss:  1.061446189880371
Validation States Loss:  2.699982166290283
Validation Crop Accuracy:  58.37797433661393
Validation States Accuracy:  28.80279058178647
Epoch:  2
Crop Loss:  0.9305367541524161
State Loss:  2.2799721538493065
Crop Accuracy:  60.845191296163
State Accuracy:  28.95465367366148
Validation Crop Loss:  0.9835046529769897
Validation States Loss:  2.45900559425354
Validation Crop Accuracy:  64.74398903700012
Validation States Accuracy:  28.877538308209793
Epoch:  3
Crop Loss:  0.8069555794243264
State Loss:  2.0620343174554607
Crop Accuracy:  67.39244781573557
State Accuracy:  34.64509163390731
Validation Crop Loss:  0.7115528583526611
Validation States Loss:  1.7276532649993896
Validation Crop Accuracy:  57.61803911797683
Validation States Accuracy:  31.207175781736638
Epoch:  4
Crop Loss:  0.675626997610109
State

In [7]:
total = 0
crop_correct = 0
state_correct = 0
model.eval()
crop_output_list = []
state_output_list = []
         #torch.utils.data.DataLoader(full_set.train_samples, batch_size=BATCH_SIZE)
testing = torch.utils.data.DataLoader(full_set.test_samples)

for batch_idx, batch in enumerate(testing):
    crop_label_idx = batch['crop_idx']
    img_paths = batch['img_path']
    splits = batch['split']
    state_label_idx = batch['state_idx']
    images = []
    for path, split in zip(img_paths, splits):
        images.append(full_set.load_image_from_path(path, split))

    images_tensor = torch.stack(images, dim=0)
    #batch_metrics = train_batch(batch_idx, images_tensor, crop_label_idx, state_label_idx)
    inputs = images_tensor.clone().detach().requires_grad_(True)
    crop_labels = crop_label_idx.clone().detach()
    state_labels = state_label_idx.clone().detach()

    inputs = inputs.to(device)
    crop_labels = crop_labels.to(device)
    state_labels = state_labels.to(device)
            # Zero the parameter gradients
    optimiser.zero_grad()
            
            # Forward pass
    crop_outputs, state_outputs = model(inputs)

    _, crop_predicted = crop_outputs.max(1)
    _, state_predicted = state_outputs.max(1)
    total += crop_labels.size(0)
    crop_correct += crop_predicted.eq(crop_labels).sum().item()
    state_correct += state_predicted.eq(state_labels).sum().item()

    crop_output_list.append({"Path":path,
                             "Label":crop_labels,
                             "Predicted":crop_predicted,
                             "Output":torch.nn.functional.softmax(crop_outputs)})
    
    state_output_list.append({"Path":path,
                             "Label":state_labels,
                             "Predicted":state_predicted,
                             "Output":torch.nn.functional.softmax(state_outputs)})
    
#with torch.no_grad():
#    pred = model(test_set)

In [8]:
crop_accuracy = 100. * crop_correct / len(full_set.test_samples)
state_accuracy = 100. * state_correct / len(full_set.test_samples)
print("State Test Accuracy: ", state_accuracy)
print("Crop Test Accuracy: ", crop_accuracy)

State Test Accuracy:  91.11324606701093
Crop Test Accuracy:  98.59893519074497


In [29]:
#torch.nn.functional.softmax(crop_outputs)
len(full_set.test_samples)

24981

In [20]:
crop_label_test = []
state_label_test = []
testing = torch.utils.data.DataLoader(full_set.test_samples)
count = 0
for batch_idx, batch in enumerate(testing):
        
        crop_label_idx = batch['crop_idx']
        if crop_label_idx not in crop_label_test:
                crop_label_test.append(crop_label_idx)
        img_paths = batch['img_path']
        
        splits = batch['split']
        state_label_idx = batch['state_idx']
        if state_label_idx not in state_label_test:
                state_label_test.append(state_label_idx)
        
        if batch_idx == 16000:
                print(img_paths)
                print(crop_label_idx)
                print(crop_label_idx.clone().detach())
                print(full_set.unique_crops)
                print(state_label_idx)
                print(full_set.unique_states)
                
        

['/scratch/braines/Dataset/CCMT-Dataset-Augmented/test_data/Maize/leaf beetle/714maize_valid_leaf beetle.JPG']
tensor([3])
tensor([3])
['tomato', 'cassava', 'cashew', 'maize']
tensor([4])
['leaf spot', 'streak virus', 'leaf curl', 'mosaic', 'leaf beetle', 'verticulium wilt', 'anthracnose', 'grasshoper', 'bspot', 'cashew healthy', 'bb', 'cassava healthy', 'septoria leaf spot', 'tomato healthy', 'gumosis', 'leaf miner', 'farmyw', 'maize healthy', 'leaf blight', 'gmite', 'red rust']


In [17]:
len(crop_label_test)

4

In [22]:
torch.save(model.state_dict(), "/scratch/braines/vgg16/best_model_30_epoch.pth")

In [39]:
def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

In [40]:
x, y = find_classes(test_dir)

In [63]:
def online_mean_and_sd(loader):
    """Compute the mean and sd in an online fashion

        Var[x] = E[X^2] - E^2[X]
    """
    cnt = 0
    fst_moment = torch.empty(3)
    snd_moment = torch.empty(3)

    for data in loader:

        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)

        cnt += nb_pixels

    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

In [65]:
normal_set = torch.utils.data.DataLoader(full_set.test_sample, batch_size=1, shuffle=False)

mean = 0.
std = 0.
for images, _ in normal_set:
    batch_samples = images.size(0) # batch size (the last batch can have smaller size!)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(normal_set.dataset)
std /= len(normal_set.dataset)

In [66]:
print("mean: ", mean)
print("Std: ", std)

mean:  tensor([0.4851, 0.5189, 0.3830])
Std:  tensor([0.2000, 0.1880, 0.2216])
