<a href="https://colab.research.google.com/github/YolandaMDavis/NSSADNN_IQA/blob/wildtrack-iqa/wildtrack_multitask_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive') # for google colab. adjust accordingly
PARENT_DIR = '/content/drive/MyDrive/Wildtrack Group/IQA' 

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import shutil
from zipfile import ZipFile

# copy and extract tar file
shutil.copy(PARENT_DIR + '/data/WildTrack_Raw.zip', 'WildTrack_Raw.zip')

with ZipFile('WildTrack_Raw.zip', 'r') as zipObj:
   # Extract all the contents of zip file in current directory
   zipObj.extractall()

# copy supporting project files
shutil.copy(PARENT_DIR + '/multitask_model/WildTrackDataset.py', 'WildTrackDataset.py')
shutil.copy(PARENT_DIR + '/multitask_model/IQADataset.py', 'IQADataset.py')
shutil.copy(PARENT_DIR + '/multitask_model/network.py', 'network.py')
shutil.copy(PARENT_DIR + '/multitask_model/brisque.py', 'brisque.py')
shutil.copy(PARENT_DIR + '/multitask_model/config.yaml', 'config.yaml')

'config.yaml'

In [3]:
# director variables
root_dir = '/content'
data_dir = root_dir + '/RAW'
image_reference_file_suffix = '_image_references.csv'

In [4]:
import os
import csv
import random

import torch
import yaml

from torch.utils.data import Dataset

def generate_data_files():

    image_reference_list = []

    subdirectories = list(os.walk(data_dir, topdown=False))[:-1]
    for subdir in subdirectories:
        image_location = subdir[0]
        images = subdir[2]
        species_rating = image_location.rsplit('/', 1)[-1].replace('_', ' ')
        score = int(species_rating.rsplit(' ', 1)[-1])
        species_class = species_rating.rsplit(' ', 1)[:-1][0]
        if len(species_class.rsplit(' ', 1)) > 1:
            species = species_class.rsplit(' ')[0]
            animal_class = ' '.join(species_class.rsplit(' ')[1:])
        else:
            animal_class = 'Unknown'
            species = species_class

        for image in images:
            image_reference = (image_location, species, animal_class, image, score)
            image_reference_list.append(image_reference)

    # shuffle then split
    seed = 1234
    random.Random(seed).shuffle(image_reference_list)
    training = image_reference_list[:int(len(image_reference_list) * 0.06)]
    validation = image_reference_list[-int(len(image_reference_list) * 0.02):]
    testing = image_reference_list[-int(len(image_reference_list) * 0.02):]

    # generated reference splits
    for dataset in [('training', training), ('validation', validation), ('testing', testing)]:
        ref_file_name = root_dir + '/' + dataset[0] + image_reference_file_suffix
        with open(ref_file_name, 'w', newline='') as csvfile:
            image_ref_writer = csv.writer(csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_NONNUMERIC)
            image_ref_writer.writerows(dataset[1])

generate_data_files()

In [5]:
import numpy as np
from scipy import stats
import torch.nn as nn
import random
from network import NSSADNN
from WildTrackDataset import WildTrackDataset
import yaml

batch_size =8
epochs = 3
lr = 0.001
weight_decay=0.0001

save_model = root_dir + "/model.pth"

seed = random.randint(10000000, 99999999)

torch.manual_seed(seed)
np.random.seed(seed)
print("seed:", seed)

with open(root_dir + "/config.yaml") as f:
    config = yaml.load(f)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = WildTrackDataset(root_dir + '/' + 'training' + image_reference_file_suffix, config, "train")
train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=0)


val_dataset = WildTrackDataset(root_dir + '/' + 'validation' + image_reference_file_suffix, config, "validation")
val_loader = torch.utils.data.DataLoader(val_dataset)
valnum = val_dataset.row_count

test_dataset = WildTrackDataset(root_dir + '/' + 'testing' + image_reference_file_suffix, config, "testing")
test_loader = torch.utils.data.DataLoader(test_dataset)
testnum = test_dataset.row_count


seed: 85208021
Processing file number:0
Patch length is: 1302
Processing file number:1
Patch length is: 1302
Processing file number:2
Patch length is: 20584
Processing file number:3
Patch length is: 1302
Processing file number:4
Patch length is: 1302
Processing file number:5
Patch length is: 1302
Processing file number:6
Patch length is: 1302
Processing file number:7
Patch length is: 21888
Processing file number:8
Patch length is: 1302
Processing file number:9
Patch length is: 1302
Processing file number:10
Patch length is: 1302
Processing file number:11
Patch length is: 1302
Processing file number:12
Patch length is: 1302
Processing file number:13
Patch length is: 21888
Processing file number:14
Patch length is: 1302
Processing file number:15
Patch length is: 1302
Processing file number:16
Patch length is: 1302
Processing file number:17
Patch length is: 10375
Processing file number:18
Patch length is: 1302
Processing file number:19
Patch length is: 1302
Processing file number:20
Patch

In [6]:
model = NSSADNN().to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
torch.optim.lr_scheduler.StepLR(optimizer, 750, gamma=0.1, last_epoch=-1)
best_SROCC = -1

# training 
for epoch in range(epochs):
      # train
      model.train()
      LOSS_all = 0
      LOSS_NSS = 0
      LOSS_q = 0
      for i, (patches, (label, features)) in enumerate(train_loader):
          patches = patches.to(device)
          label = label.to(device)
          features = features.to(device).float()

          optimizer.zero_grad()
          outputs_q = model(patches)[0]
          outputs_NSS = model(patches)[1]

          loss_NSS = criterion(outputs_NSS, features)
          loss_q = criterion(outputs_q, label)
          loss = loss_NSS + loss_q

          loss.backward()
          optimizer.step()
          LOSS_all += float(loss.item())
          LOSS_NSS += float(loss_NSS.item())
          LOSS_q += float(loss_q.item())
      train_loss_all = LOSS_all / (i + 1)
      train_loss_NSS = LOSS_NSS / (i + 1)
      train_loss_q = LOSS_q / (i + 1)

      # val
      y_pred = np.zeros(valnum)
      y_val = np.zeros(valnum)
      model.eval()
      L = 0
      with torch.no_grad():
          for i, (patches, (label, features)) in enumerate(val_loader):
              y_val[i] = label.item()
              patches = patches.to(device)
              label = label.to(device)
              outputs_q = model(patches)[0]
              score = outputs_q.mean()
              y_pred[i] = score
              loss = criterion(score, label[0])
              L = L + loss.item()
      val_loss = L / (i + 1)

      val_SROCC = stats.spearmanr(y_pred, y_val)[0]
      val_PLCC = stats.pearsonr(y_pred, y_val)[0]
      val_KROCC = stats.stats.kendalltau(y_pred, y_val)[0]
      val_RMSE = np.sqrt(((y_pred - y_val) ** 2).mean())

      # test
      y_pred = np.zeros(testnum)
      y_test = np.zeros(testnum)
      L = 0
      with torch.no_grad():
          for i, (patches, (label, features)) in enumerate(test_loader):
              y_test[i] = label.item()
              patches = patches.to(device)
              label = label.to(device)
              outputs_q = model(patches)[0]

              score = outputs_q.mean()
              y_pred[i] = score
              loss = criterion(score, label[0])
              L = L + loss.item()
      test_loss = L / (i + 1)
      SROCC = stats.spearmanr(y_pred, y_test)[0]
      PLCC = stats.pearsonr(y_pred, y_test)[0]
      KROCC = stats.stats.kendalltau(y_pred, y_test)[0]
      RMSE = np.sqrt(((y_pred - y_test) ** 2).mean())

      print("Epoch {} Valid Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(epoch,
                                                                                                            val_loss,
                                                                                                            val_SROCC,
                                                                                                            val_PLCC,
                                                                                                            val_KROCC,
                                                                                                            val_RMSE))
      print("Epoch {} Test Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(epoch,
                                                                                                          test_loss,
                                                                                                          SROCC,
                                                                                                          PLCC,
                                                                                                          KROCC,
                                                                                                          RMSE))

      if val_SROCC > best_SROCC and epoch > 100:
          print("Update Epoch {} best valid SROCC".format(epoch))
          print("Valid Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(val_loss,
                                                                                                      val_SROCC,
                                                                                                      val_PLCC,
                                                                                                      val_KROCC,
                                                                                                      val_RMSE))
          print("Test Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(test_loss,
                                                                                                      SROCC,
                                                                                                      PLCC,
                                                                                                      KROCC,
                                                                                                      RMSE))
          torch.save(model.state_dict(), save_model)
          best_SROCC = val_SROCC


  return F.l1_loss(input, target, reduction=self.reduction)


Epoch 0 Valid Results: loss=0.792 SROCC=-0.159 PLCC=-0.158 KROCC=-0.146 RMSE=1.117
Epoch 0 Test Results: loss=0.792 SROCC=-0.159 PLCC=-0.158 KROCC=-0.146 RMSE=1.117
Epoch 1 Valid Results: loss=0.792 SROCC=0.232 PLCC=0.172 KROCC=0.206 RMSE=1.118
Epoch 1 Test Results: loss=0.792 SROCC=0.232 PLCC=0.172 KROCC=0.206 RMSE=1.118
Epoch 2 Valid Results: loss=0.792 SROCC=-0.032 PLCC=-0.036 KROCC=-0.029 RMSE=1.118
Epoch 2 Test Results: loss=0.792 SROCC=-0.032 PLCC=-0.036 KROCC=-0.029 RMSE=1.118


In [9]:
# final test
torch.save(model.state_dict(), save_model)

model.load_state_dict(torch.load(save_model))
model.eval()
with torch.no_grad():
    y_pred = np.zeros(testnum)
    y_test = np.zeros(testnum)
    L = 0
    for i, (patches, (label, features)) in enumerate(test_loader):
        y_test[i] = label.item()
        patches = patches.to(device)
        label = label.to(device)

        outputs = model(patches)[0]
        score = outputs.mean()

        y_pred[i] = score
        loss = criterion(score, label[0])
        L = L + loss.item()
test_loss = L / (i + 1)
SROCC = stats.spearmanr(y_pred, y_test)[0]
PLCC = stats.pearsonr(y_pred, y_test)[0]
KROCC = stats.stats.kendalltau(y_pred, y_test)[0]
RMSE = np.sqrt(((y_pred - y_test) ** 2).mean())

print("Final test Results: loss={:.3f} SROCC={:.3f} PLCC={:.3f} KROCC={:.3f} RMSE={:.3f}".format(test_loss,
                                                                                                  SROCC,
                                                                                                  PLCC,
                                                                                                  KROCC,
                                                                                                  RMSE))

Final test Results: loss=0.792 SROCC=-0.032 PLCC=-0.036 KROCC=-0.029 RMSE=1.118


In [12]:
y_test

3.5625

In [11]:
y_pred

array([4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118446, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118542, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118446, 4.00118494,
       4.00118494, 4.00118446, 4.00118494, 4.00118494, 4.00118494,
       4.00118446, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118494,
       4.00118494, 4.00118494, 4.00118494, 4.00118494, 4.00118