In [1]:
import vgg_class
import torch
from torch import nn
import torch.optim as optim
from torchvision import transforms
from data import DatasetManager
from PIL import Image

In [2]:
# Per-split preprocessing. Both splits currently use the same pipeline
# (resize to 224x224, then convert to a tensor); each split still gets its own
# Compose instance so they can diverge later without affecting each other.
data_standard_transforms = {
    split_name: transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    for split_name in ('train', 'test')
}

# NOTE(review): absolute, user-specific path — consider making it configurable.
root_dir = "/scratch/braines/Dataset/CCMT-Dataset-Augmented/train_Data/Cashew/"
Cashew_train_set = DatasetManager(root_dir, transform=data_standard_transforms)

In [19]:
# ---- Hyper-parameters ----
NUM_EPOCHS = 1  # Number of passes through entire training dataset
CV_FOLDS = 2  # Number of cross-validation folds
BATCH_SIZE = 32  # Within each epoch data is split into batches
LEARNING_RATE = 0.001  # Adam learning rate
VAL_SPLIT = 0.2
CROSS_VALIDATE = True
SEED = 42  # Seeds torch's RNGs so weight initialisation is reproducible

# Seed before the model is built so its initial weights are identical on every
# Restart & Run All.
torch.manual_seed(SEED)

# ---- Model ----
# vgg16 receives (number of crop classes, number of state classes); presumably
# one output head per task — confirm in vgg_class.
num_crop_classes = len(Cashew_train_set.unique_crops)
num_state_classes = len(Cashew_train_set.unique_states)
vgg = vgg_class.vgg16((num_crop_classes, num_state_classes))

# Replicate across every visible GPU; device_count() is 0 on a CPU-only host,
# in which case this list is empty.
device_ids = list(range(torch.cuda.device_count()))
model = nn.DataParallel(vgg, device_ids=device_ids)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ---- Losses and optimiser ----
# One cross-entropy loss per task.
criterion_crop = nn.CrossEntropyLoss()
criterion_state = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [20]:
# Switch the model to training mode (affects any dropout / batch-norm layers it
# contains), then batch the training sample records.
# NOTE(review): shuffle is not requested, so batches follow the order of
# train_samples; if that order is grouped by class, batches may be
# class-homogeneous — consider shuffle=True for real training runs.
model.train()
train_loader = torch.utils.data.DataLoader(
    dataset=Cashew_train_set.train_samples,
    batch_size=BATCH_SIZE,
)

In [21]:
# Gradient probe: for the first NUM_PROBE_BATCHES batches, record the gradient
# of each task loss on its own (crop vs. state), then take one optimiser step
# on their sum, i.e. the gradient of crop_loss + state_loss.
#
# Each entry of grad_crop_list / grad_state_list is a dict
# {parameter name: gradient tensor} covering every named parameter of `model`.
# Parameters a loss does not reach (presumably the other task's head, plus any
# frozen parameters) get an all-zero tensor instead of crashing on a None grad.
#
# NOTE(review): every entry is a full copy of the model's gradients on `device`;
# for a VGG16-sized model that is large. Move them to CPU (`.cpu()`) if GPU
# memory becomes a problem.
NUM_PROBE_BATCHES = 3  # stop after this many batches

# Only parameters that require grad can be passed to torch.autograd.grad.
trainable_named_params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
trainable_params = [p for _, p in trainable_named_params]


def _load_batch_images(batch):
    """Load every image referenced by a DataLoader batch and stack them along dim 0."""
    images = [
        Cashew_train_set.load_image_from_path(path, split)
        for path, split in zip(batch['img_path'], batch['split'])
    ]
    return torch.stack(images, dim=0)


def _grad_snapshot(grads):
    """Map each named parameter of `model` to its gradient in `grads`, or zeros if it has none.

    `grads` is aligned with `trainable_params`, as returned by
    torch.autograd.grad(..., allow_unused=True); a None entry marks a parameter
    the loss does not depend on.
    """
    grad_by_name = {name: g for (name, _), g in zip(trainable_named_params, grads)}
    return {
        name: grad_by_name[name] if grad_by_name.get(name) is not None else torch.zeros_like(p)
        for name, p in model.named_parameters()
    }


def _summed_grad(crop_grad, state_grad):
    """Return crop_grad + state_grad as a new tensor, treating None as absent (None if both are)."""
    if crop_grad is None and state_grad is None:
        return None
    if crop_grad is None:
        return state_grad.clone()  # clone so .grad never aliases a recorded snapshot
    if state_grad is None:
        return crop_grad.clone()
    return crop_grad + state_grad


grad_crop_list = []
grad_state_list = []
for batch_idx, batch in enumerate(train_loader):
    # No clone/detach/requires_grad on the inputs: nothing here needs gradients
    # with respect to the images.
    inputs = _load_batch_images(batch).to(device)
    crop_labels = batch['crop_idx'].to(device)
    state_labels = batch['state_idx'].to(device)

    # Forward pass: the model returns one output per task.
    crop_outputs, state_outputs = model(inputs)
    crop_loss = criterion_crop(crop_outputs, crop_labels)
    state_loss = criterion_state(state_outputs, state_labels)

    # Per-task gradients. Unlike two successive .backward() calls, autograd.grad
    # does not accumulate into param.grad, so the state gradient is not mixed
    # with the crop gradient. retain_graph keeps the shared graph alive for the
    # second call.
    crop_grads = torch.autograd.grad(
        crop_loss, trainable_params, retain_graph=True, allow_unused=True
    )
    state_grads = torch.autograd.grad(state_loss, trainable_params, allow_unused=True)

    grad_crop_list.append(_grad_snapshot(crop_grads))
    grad_state_list.append(_grad_snapshot(state_grads))

    # Step on the combined gradient, the same update the two-backward version made.
    # Parameters reached by neither loss keep grad=None so the optimiser skips them.
    for p, crop_grad, state_grad in zip(trainable_params, crop_grads, state_grads):
        p.grad = _summed_grad(crop_grad, state_grad)
    optimiser.step()

    if batch_idx + 1 >= NUM_PROBE_BATCHES:
        break

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
