<a href="https://colab.research.google.com/github/YolandaMDavis/NSSADNN_IQA/blob/wildtrack-iqa/wildtrack_multitask_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Only needed to copy data to local drive can be skipped if zip file is already available in working folder
import shutil
from zipfile import ZipFile

# mount google drive
from google.colab import drive
drive.mount('/content/drive') # for google colab. adjust accordingly
PARENT_DIR = '/content/drive/MyDrive/Wildtrack Group/IQA' 

# copy and extract tar file
shutil.copy(PARENT_DIR + '/data/WildTrack_Raw.zip', 'WildTrack_Raw.zip')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Clone repo and copy in images. change working directory to repo's wildtrack branch

!git clone https://github.com/YolandaMDavis/NSSADNN_IQA.git
!mv WildTrack_Raw.zip NSSADNN_IQA/.
%cd "NSSADNN_IQA"
!git checkout wildtrack-iqa

with ZipFile('WildTrack_Raw.zip', 'r') as zipObj:
   # Extract all the contents of zip file in current directory
   zipObj.extractall()


Cloning into 'NSSADNN_IQA'...
remote: Enumerating objects: 58, done.[K
remote: Counting objects: 100% (58/58), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 58 (delta 27), reused 29 (delta 8), pack-reused 0[K
Unpacking objects: 100% (58/58), done.
/content/NSSADNN_IQA
Branch 'wildtrack-iqa' set up to track remote branch 'wildtrack-iqa' from 'origin'.
Switched to a new branch 'wildtrack-iqa'


In [5]:
# director variables
root_dir = '/content/NSSADNN_IQA'
data_dir = root_dir + '/RAW'
image_reference_file_suffix = '_image_references.csv'

In [6]:
import os
import csv
import random

import torch
import yaml

from torch.utils.data import Dataset

def generate_data_files(sample_percentage=1):

    image_reference_list = []

    subdirectories = list(os.walk(data_dir, topdown=False))[:-1]
    for subdir in subdirectories:
        image_location = subdir[0]
        images = subdir[2]
        species_rating = image_location.rsplit('/', 1)[-1].replace('_', ' ')
        score = int(species_rating.rsplit(' ', 1)[-1])
        species_class = species_rating.rsplit(' ', 1)[:-1][0]
        if len(species_class.rsplit(' ', 1)) > 1:
            species = species_class.rsplit(' ')[0]
            animal_class = ' '.join(species_class.rsplit(' ')[1:])
        else:
            animal_class = 'Unknown'
            species = species_class

        for image in images:
            image_reference = (image_location, species, animal_class, image, score)
            image_reference_list.append(image_reference)

    # shuffle then split
    seed = 1234
    random.Random(seed).shuffle(image_reference_list)
    training = image_reference_list[:int(len(image_reference_list) * 0.6 * sample_percentage)]
    validation = image_reference_list[-int(len(image_reference_list) * 0.2 * sample_percentage):]
    testing = image_reference_list[-int(len(image_reference_list) * 0.2 * sample_percentage):]

    # generated reference splits
    for dataset in [('training', training), ('validation', validation), ('testing', testing)]:
        ref_file_name = root_dir + '/' + dataset[0] + image_reference_file_suffix
        with open(ref_file_name, 'w', newline='') as csvfile:
            image_ref_writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_NONNUMERIC)
            image_ref_writer.writerows(dataset[1])


In [10]:
# take a percentage of the full data set as a training/test/validation sample
sample_size=.1
generate_data_files(sample_size)

In [19]:
import numpy as np
from scipy import stats
import torch.nn as nn
import random
from network import NSSADNN
from WildTrackDataset import WildTrackDataset

In [26]:
save_model = root_dir + "/model.pth"

seed = random.randint(10000000, 99999999)

torch.manual_seed(seed)
np.random.seed(seed)
print("seed:", seed)

config = {}
config["patch_size"] = 32
config["stride"] = 16

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = WildTrackDataset(root_dir + '/' + 'training' + image_reference_file_suffix, config, "train")
val_dataset = WildTrackDataset(root_dir + '/' + 'validation' + image_reference_file_suffix, config, "validation")
test_dataset = WildTrackDataset(root_dir + '/' + 'testing' + image_reference_file_suffix, config, "testing")

seed: 97693607
Processing file number:0
Processing file number:1
Processing file number:2
Processing file number:3
Processing file number:4
Processing file number:5
Processing file number:6
Processing file number:7
Processing file number:8
Processing file number:9
Processing file number:10
Processing file number:11
Processing file number:12
Processing file number:13
Processing file number:14
Processing file number:15
Processing file number:16
Processing file number:17
Processing file number:18
Processing file number:19
Processing file number:20
Processing file number:21
Processing file number:22
Processing file number:23
Processing file number:24
Processing file number:25
Processing file number:26
Processing file number:27
Processing file number:28
Processing file number:29
Processing file number:30
Processing file number:31
Processing file number:32
Processing file number:33
Processing file number:34
Processing file number:35
Processing file number:36
Processing file number:37
Process

In [33]:
batch_size = 32
weight_decay=0.0001
epochs = 20
lr = 0.001

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=0)

val_loader = torch.utils.data.DataLoader(val_dataset)
valnum = val_dataset.row_count

test_loader = torch.utils.data.DataLoader(test_dataset)
testnum = test_dataset.row_count


model = NSSADNN().to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
torch.optim.lr_scheduler.StepLR(optimizer, 750, gamma=0.1, last_epoch=-1)
best_SROCC = -1

# training 
for epoch in range(epochs):
      # train
      model.train()
      LOSS_all = 0
      LOSS_NSS = 0
      LOSS_q = 0
      for i, (patches, (label, features)) in enumerate(train_loader):
          patches = patches.to(device)
          label = label.to(device)
          features = features.to(device).float()

          optimizer.zero_grad()
          outputs_q = model(patches)[0]
          outputs_NSS = model(patches)[1]

          loss_NSS = criterion(outputs_NSS, features)
          loss_q = criterion(outputs_q, label)
          loss = loss_NSS + loss_q

          loss.backward()
          optimizer.step()
          LOSS_all += float(loss.item())
          LOSS_NSS += float(loss_NSS.item())
          LOSS_q += float(loss_q.item())
      train_loss_all = LOSS_all / (i + 1)
      train_loss_NSS = LOSS_NSS / (i + 1)
      train_loss_q = LOSS_q / (i + 1)

      # val
      y_pred = np.zeros(valnum)
      y_val = np.zeros(valnum)
      model.eval()
      L = 0
      with torch.no_grad():
          for i, (patches, (label, features)) in enumerate(val_loader):
              y_val[i] = label.item()
              patches = patches.to(device)
              label = label.to(device)
              outputs_q = model(patches)[0]
              score = outputs_q.mean()
              y_pred[i] = score
              loss = criterion(score, label[0])
              L = L + loss.item()
      val_loss = L / (i + 1)

      val_SROCC = stats.spearmanr(y_pred, y_val)[0]
      val_PLCC = stats.pearsonr(y_pred, y_val)[0]
      val_KROCC = stats.stats.kendalltau(y_pred, y_val)[0]
      val_RMSE = np.sqrt(((y_pred - y_val) ** 2).mean())


      print("Epoch {} Valid Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(epoch,
                                                                                                            val_loss,
                                                                                                            val_SROCC,
                                                                                                            val_PLCC,
                                                                                                            val_KROCC,
                                                                                                            val_RMSE))

      if val_SROCC > best_SROCC and epoch > 100:
          print("Update Epoch {} best valid SROCC".format(epoch))
          print("Valid Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(val_loss,
                                                                                                      val_SROCC,
                                                                                                      val_PLCC,
                                                                                                      val_KROCC,
                                                                                                      val_RMSE))

          torch.save(model.state_dict(), save_model)
          best_SROCC = val_SROCC


Epoch 0 Valid Results: loss=1.100 SROCC=0.051 PLCC=-0.011 KROCC=0.036 RMSE=1.326
Epoch 1 Valid Results: loss=0.941 SROCC=0.011 PLCC=0.054 KROCC=0.005 RMSE=1.204
Epoch 2 Valid Results: loss=0.953 SROCC=-0.191 PLCC=-0.052 KROCC=-0.143 RMSE=1.173
Epoch 3 Valid Results: loss=0.936 SROCC=-0.007 PLCC=0.108 KROCC=-0.009 RMSE=1.181
Epoch 4 Valid Results: loss=0.928 SROCC=-0.005 PLCC=0.126 KROCC=-0.004 RMSE=1.199
Epoch 5 Valid Results: loss=0.931 SROCC=0.094 PLCC=0.238 KROCC=0.064 RMSE=1.191
Epoch 6 Valid Results: loss=0.934 SROCC=0.114 PLCC=0.158 KROCC=0.088 RMSE=1.188
Epoch 7 Valid Results: loss=0.930 SROCC=0.184 PLCC=0.256 KROCC=0.133 RMSE=1.193
Epoch 8 Valid Results: loss=0.930 SROCC=0.087 PLCC=0.164 KROCC=0.066 RMSE=1.196
Epoch 9 Valid Results: loss=0.926 SROCC=0.283 PLCC=0.361 KROCC=0.211 RMSE=1.197
Epoch 10 Valid Results: loss=0.928 SROCC=0.027 PLCC=0.118 KROCC=0.016 RMSE=1.201
Epoch 11 Valid Results: loss=0.927 SROCC=0.271 PLCC=0.332 KROCC=0.203 RMSE=1.196
Epoch 12 Valid Results: loss=0

In [34]:
# final test
torch.save(model.state_dict(), save_model)

model.load_state_dict(torch.load(save_model))
model.eval()
with torch.no_grad():
    y_pred = np.zeros(testnum)
    y_test = np.zeros(testnum)
    L = 0
    for i, (patches, (label, features)) in enumerate(test_loader):
        y_test[i] = label.item()
        patches = patches.to(device)
        label = label.to(device)

        outputs = model(patches)[0]
        score = outputs.mean()

        y_pred[i] = score
        loss = criterion(score, label[0])
        L = L + loss.item()
test_loss = L / (i + 1)
SROCC = stats.spearmanr(y_pred, y_test)[0]
PLCC = stats.pearsonr(y_pred, y_test)[0]
KROCC = stats.stats.kendalltau(y_pred, y_test)[0]
RMSE = np.sqrt(((y_pred - y_test) ** 2).mean())

print("Final test Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(test_loss,
                                                                                                  SROCC,
                                                                                                  PLCC,
                                                                                                  KROCC,
                                                                                                  RMSE))

Final test Results: loss=0.902 SROCC=0.347 PLCC=0.378 KROCC=0.257 RMSE=1.102


In [35]:
y_test

array([3., 4., 3., 3., 2., 5., 5., 3., 4., 5., 5., 5., 4., 4., 5., 1., 4.,
       4., 5., 3., 4., 2., 5., 3., 4., 5., 3., 3., 5., 4., 2., 3., 5., 3.,
       1., 5., 3., 4., 3., 3., 4., 5., 3., 5., 4., 5., 5., 3., 2., 4., 3.,
       5., 2., 5., 4., 4., 4., 4., 1., 5., 2., 5., 3., 4., 5., 4., 5., 3.,
       5., 4., 4., 5., 4., 5., 5., 2., 3., 4., 3., 3., 4., 5., 1., 2., 5.,
       2., 4., 1., 3., 3., 5., 4., 3., 4., 3., 2.])

In [36]:
y_pred

array([3.49649525, 3.8488152 , 3.82291579, 3.40156889, 3.73796296,
       3.86409712, 3.97929382, 3.43160295, 3.77204633, 3.79290915,
       3.8627069 , 3.7951293 , 3.90003753, 3.81461215, 3.77649808,
       3.29390693, 3.85351086, 3.76837111, 3.71456218, 3.73433423,
       3.78590393, 3.73904967, 3.81949449, 3.67965508, 4.13626575,
       3.77891397, 3.45730448, 3.84554648, 3.9262991 , 3.9518013 ,
       3.48791361, 3.58318114, 3.8865788 , 3.7758925 , 3.73253369,
       3.87430072, 3.62002397, 3.73118401, 3.97094512, 3.73111939,
       3.79931974, 3.83701134, 3.55569458, 3.99449873, 3.27101469,
       3.93501806, 3.91533494, 4.08376551, 3.72365689, 3.97489405,
       3.74845576, 3.32286954, 3.27333713, 3.65396762, 3.70929146,
       3.80980587, 3.72893977, 3.68569899, 3.41213322, 3.72953343,
       3.83491611, 3.73140526, 3.66633034, 3.38981581, 3.67772722,
       4.00991774, 3.59634256, 3.83149457, 3.53208137, 3.92337894,
       3.75965118, 3.75458407, 3.76925159, 3.6636138 , 3.52949