<a href="https://colab.research.google.com/github/YolandaMDavis/NSSADNN_IQA/blob/wildtrack-iqa/wildtrack_multitask_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Only needed to copy data to local drive can be skipped if zip file is already available in working folder
import shutil
from zipfile import ZipFile

# mount google drive
from google.colab import drive
drive.mount('/content/drive') # for google colab. adjust accordingly
PARENT_DIR = '/content/drive/MyDrive/Wildtrack Group/IQA' 

# copy and extract tar file
shutil.copy(PARENT_DIR + '/data/WildTrack_Raw.zip', 'WildTrack_Raw.zip')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


'WildTrack_Raw.zip'

In [2]:
# Clone repo and copy in images. change working directory to repo's wildtrack branch

!git clone https://github.com/YolandaMDavis/NSSADNN_IQA.git
!mv WildTrack_Raw.zip NSSADNN_IQA/.
%cd "NSSADNN_IQA"
!git checkout wildtrack-iqa

with ZipFile('WildTrack_Raw.zip', 'r') as zipObj:
   # Extract all the contents of zip file in current directory
   zipObj.extractall()


Cloning into 'NSSADNN_IQA'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (52/52), done.[K
remote: Total 61 (delta 29), reused 28 (delta 8), pack-reused 0[K
Unpacking objects: 100% (61/61), done.
/content/NSSADNN_IQA
Branch 'wildtrack-iqa' set up to track remote branch 'wildtrack-iqa' from 'origin'.
Switched to a new branch 'wildtrack-iqa'


In [1]:
# director variables
root_dir = '/content/NSSADNN_IQA'
data_dir = root_dir + '/RAW'
image_reference_file_suffix = '_image_references.csv'
%cd "NSSADNN_IQA"

/content/NSSADNN_IQA


In [2]:
import os
import csv
import random

import torch
import yaml

from torch.utils.data import Dataset

def generate_data_files(sample_percentage=1):

    image_reference_list = []

    subdirectories = list(os.walk(data_dir, topdown=False))[:-1]
    for subdir in subdirectories:
        image_location = subdir[0]
        images = subdir[2]
        species_rating = image_location.rsplit('/', 1)[-1].replace('_', ' ')
        score = int(species_rating.rsplit(' ', 1)[-1])
        species_class = species_rating.rsplit(' ', 1)[:-1][0]
        if len(species_class.rsplit(' ', 1)) > 1:
            species = species_class.rsplit(' ')[0]
            animal_class = ' '.join(species_class.rsplit(' ')[1:])
        else:
            animal_class = 'Unknown'
            species = species_class

        for image in images:
            image_reference = (image_location, species, animal_class, image, score)
            image_reference_list.append(image_reference)

    # shuffle then split
    seed = 1234
    random.Random(seed).shuffle(image_reference_list)
    training = image_reference_list[:int(len(image_reference_list) * 0.6 * sample_percentage)]
    validation = image_reference_list[-int(len(image_reference_list) * 0.2 * sample_percentage):]
    testing = image_reference_list[-int(len(image_reference_list) * 0.2 * sample_percentage):]

    # generated reference splits
    for dataset in [('training', training), ('validation', validation), ('testing', testing)]:
        ref_file_name = root_dir + '/' + dataset[0] + image_reference_file_suffix
        with open(ref_file_name, 'w', newline='') as csvfile:
            image_ref_writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_NONNUMERIC)
            image_ref_writer.writerows(dataset[1])


In [3]:
# take a percentage of the full data set as a training/test/validation sample
sample_size=.05
generate_data_files(sample_size)

In [4]:
import numpy as np
from scipy import stats
import torch.nn as nn
import random
from network import NSSADNN
from WildTrackDataset import WildTrackDataset

In [5]:
save_model = root_dir + "/model.pth"

seed = random.randint(10000000, 99999999)

torch.manual_seed(seed)
np.random.seed(seed)
print("seed:", seed)

config = {}
config["patch_size"] = 32
config["stride"] = 16

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = WildTrackDataset(root_dir + '/' + 'training' + image_reference_file_suffix, config, "train")
val_dataset = WildTrackDataset(root_dir + '/' + 'validation' + image_reference_file_suffix, config, "validation")
test_dataset = WildTrackDataset(root_dir + '/' + 'testing' + image_reference_file_suffix, config, "testing")

seed: 90076636
Processing file number:0
Processing file number:1
Processing file number:2
Processing file number:3
Processing file number:4
Processing file number:5
Processing file number:6
Processing file number:7
Processing file number:8
Processing file number:9
Processing file number:10
Processing file number:11
Processing file number:12
Processing file number:13
Processing file number:14
Processing file number:15
Processing file number:16
Processing file number:17
Processing file number:18
Processing file number:19
Processing file number:20
Processing file number:21
Processing file number:22
Processing file number:23
Processing file number:24
Processing file number:25
Processing file number:26
Processing file number:27
Processing file number:28
Processing file number:29
Processing file number:30
Processing file number:31
Processing file number:32
Processing file number:33
Processing file number:34
Processing file number:35
Processing file number:36
Processing file number:37
Process

In [8]:
batch_size = 128
weight_decay=0.0001
epochs = 10
lr = 0.001

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=0)

val_loader = torch.utils.data.DataLoader(val_dataset)
valnum = val_dataset.row_count

test_loader = torch.utils.data.DataLoader(test_dataset)
testnum = test_dataset.row_count

model = NSSADNN().to(device)
criterion = nn.L1Loss()
classify_criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
torch.optim.lr_scheduler.StepLR(optimizer, 750, gamma=0.1, last_epoch=-1)
best_SROCC = -1

# training 
for epoch in range(epochs):
      # train
      model.train()
      LOSS_all = 0
      LOSS_NSS = 0
      LOSS_q = 0
      for i, (patches, (label, features, species)) in enumerate(train_loader):
          patches = patches.to(device)
          label = label.to(device)
          features = features.to(device).float()
          species = species.to(device)

          optimizer.zero_grad()
          results = model(patches)
          outputs_q = results[0]
          outputs_NSS = results[1]
          outputs_species = results[2]

          loss_NSS = criterion(outputs_NSS, features)
          loss_q = criterion(outputs_q, label)
          loss_c = classify_criterion(outputs_species, species)
          loss = loss_NSS + loss_q

          loss.backward()
          optimizer.step()
          LOSS_all += float(loss.item())
          LOSS_NSS += float(loss_NSS.item())
          LOSS_q += float(loss_q.item())
      train_loss_all = LOSS_all / (i + 1)
      train_loss_NSS = LOSS_NSS / (i + 1)
      train_loss_q = LOSS_q / (i + 1)

      # val
      y_pred = np.zeros(valnum)
      y_val = np.zeros(valnum)
      model.eval()
      L = 0
      with torch.no_grad():
          for i, (patches, (label, features, species)) in enumerate(val_loader):
              y_val[i] = label.item()
              patches = patches.to(device)
              label = label.to(device)
              outputs_q = model(patches)[0]
              score = outputs_q.mean()
              y_pred[i] = score
              loss = criterion(score, label[0])
              L = L + loss.item()
      val_loss = L / (i + 1)

      val_SROCC = stats.spearmanr(y_pred, y_val)[0]
      val_PLCC = stats.pearsonr(y_pred, y_val)[0]
      val_KROCC = stats.stats.kendalltau(y_pred, y_val)[0]
      val_RMSE = np.sqrt(((y_pred - y_val) ** 2).mean())


      print("Epoch {} Valid Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(epoch,
                                                                                                            val_loss,
                                                                                                            val_SROCC,
                                                                                                            val_PLCC,
                                                                                                            val_KROCC,
                                                                                                            val_RMSE))

      if val_SROCC > best_SROCC and epoch > 100:
          print("Update Epoch {} best valid SROCC".format(epoch))
          print("Valid Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(val_loss,
                                                                                                      val_SROCC,
                                                                                                      val_PLCC,
                                                                                                      val_KROCC,
                                                                                                      val_RMSE))

          torch.save(model.state_dict(), save_model)
          best_SROCC = val_SROCC


Epoch 0 Valid Results: loss=1.145 SROCC=0.029 PLCC=-0.081 KROCC=0.020 RMSE=1.360
Epoch 1 Valid Results: loss=1.328 SROCC=0.027 PLCC=-0.056 KROCC=0.014 RMSE=1.565
Epoch 2 Valid Results: loss=1.078 SROCC=0.037 PLCC=-0.028 KROCC=0.024 RMSE=1.286
Epoch 3 Valid Results: loss=1.005 SROCC=0.004 PLCC=-0.054 KROCC=-0.012 RMSE=1.229
Epoch 4 Valid Results: loss=1.164 SROCC=0.018 PLCC=-0.017 KROCC=-0.002 RMSE=1.389
Epoch 5 Valid Results: loss=1.085 SROCC=0.044 PLCC=0.106 KROCC=0.020 RMSE=1.266
Epoch 6 Valid Results: loss=1.027 SROCC=-0.009 PLCC=0.086 KROCC=-0.018 RMSE=1.206
Epoch 7 Valid Results: loss=1.109 SROCC=0.035 PLCC=0.101 KROCC=0.020 RMSE=1.313
Epoch 8 Valid Results: loss=1.008 SROCC=0.037 PLCC=0.163 KROCC=0.028 RMSE=1.192
Epoch 9 Valid Results: loss=1.051 SROCC=0.044 PLCC=0.179 KROCC=0.036 RMSE=1.220


In [21]:
# final test
torch.save(model.state_dict(), save_model)

model.load_state_dict(torch.load(save_model))
model.eval()
with torch.no_grad():
    y_pred = np.zeros(testnum)
    y_test = np.zeros(testnum)
    L = 0
    for i, (patches, (label, features)) in enumerate(test_loader):
        y_test[i] = label.item()
        patches = patches.to(device)
        label = label.to(device)

        outputs = model(patches)[0]
        score = outputs.mean()

        y_pred[i] = score
        loss = criterion(score, label[0])
        L = L + loss.item()
test_loss = L / (i + 1)
SROCC = stats.spearmanr(y_pred, y_test)[0]
PLCC = stats.pearsonr(y_pred, y_test)[0]
KROCC = stats.stats.kendalltau(y_pred, y_test)[0]
RMSE = np.sqrt(((y_pred - y_test) ** 2).mean())

print("Final test Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(test_loss,
                                                                                                  SROCC,
                                                                                                  PLCC,
                                                                                                  KROCC,
                                                                                                  RMSE))

Final test Results: loss=0.894 SROCC=0.419 PLCC=0.394 KROCC=0.314 RMSE=1.083


In [22]:
y_test

array([3., 3., 3., 3., 4., 5., 5., 2., 5., 4., 5., 5., 2., 5., 2., 4., 3.,
       5., 4., 4., 4., 5., 3., 3., 4., 5., 3., 4., 4., 4., 5., 5., 3., 4.,
       5., 5., 4., 2., 5., 4., 4., 5., 4., 5., 2., 4., 3., 5., 4., 3., 5.,
       4., 2., 3., 3., 4., 3., 3., 4., 5., 4., 2., 2., 3., 4., 5., 5., 3.,
       2., 5., 5., 4., 4., 3., 5., 4., 2., 4., 4., 3., 3., 4., 4., 3., 5.,
       2., 2., 4., 5., 2., 3., 1., 5., 3., 1., 2., 4., 4., 4., 3., 3., 5.,
       4., 3., 3., 3., 5., 3., 3., 5., 5., 4., 5., 4., 5., 3., 1., 4., 3.,
       5., 3., 4., 3., 4., 4., 4., 4., 4., 3., 2., 3., 2., 4., 4., 2., 3.,
       4., 5., 4., 3., 4., 2., 4., 3., 4., 2., 3., 5., 3., 2., 2., 4., 5.,
       4., 3., 5., 3., 4., 4., 4., 2., 3., 4., 3., 2., 1., 4., 4., 5., 3.,
       2., 3., 3., 3., 5., 5., 5., 5., 3., 5., 5., 4., 4., 2., 3., 5., 3.,
       2., 5., 2., 5., 5., 4., 3., 2., 5., 5., 5., 4., 4., 5., 3., 5., 5.,
       3., 5., 4., 4., 4., 2., 2., 3., 5., 1., 4., 5., 4., 2., 3., 5., 5.,
       1., 5., 2., 5., 4.

In [23]:
y_pred

array([3.41472077, 3.26770091, 3.72750759, 3.69558382, 3.96568322,
       3.7688179 , 3.90765738, 3.58774161, 3.98070312, 3.81571412,
       3.87146568, 3.86927581, 3.95050621, 3.79937816, 3.73624849,
       3.65594864, 3.69446135, 3.75809216, 3.79607916, 3.81485724,
       4.0171175 , 4.06911278, 3.73132324, 3.56478953, 4.36066103,
       4.05292845, 3.73699284, 3.69815445, 3.8316505 , 3.52409196,
       3.65771174, 3.89356661, 3.52557898, 4.09866047, 4.15722322,
       3.75652814, 3.74555779, 3.73519683, 4.15567493, 3.83318019,
       3.87520266, 3.86182022, 3.82869911, 3.92059422, 3.18037224,
       3.7416358 , 3.61219788, 3.85446739, 3.75874472, 4.0548439 ,
       3.94978094, 3.69557452, 3.74291563, 3.85957766, 3.74868298,
       3.99885321, 3.73605227, 3.25028276, 4.17590714, 3.93321729,
       3.93958354, 3.17193747, 3.61127806, 3.84204268, 3.97924995,
       3.81976461, 3.91324043, 3.52556133, 3.92824244, 4.08065939,
       3.78168774, 3.64512205, 3.90056062, 3.95519495, 4.06641