In [None]:
from preprocessing import Preprocessor
from training import Trainer
from metrics import Metrics
import torch.nn as nn
import torch.optim as optim
from models.vgg_TL import GoogleNet, VGG, ResNet
from models.autoencoders import Simple_AE
from configuration import Hyperparameters as HP
import torch
import json
import math

# Year labels as strings, 2006 through 2014 (the upper bound of range is exclusive).
years = list(map(str, range(2006, 2015)))

# A 20-name class subset. In this notebook it is only referenced by a
# commented-out Preprocessor configuration.
classes = [
    "detritus", "Leptocylindrus", "Chaetoceros", "Rhizosolenia",
    "Guinardia_delicatula", "Cerataulina", "Cylindrotheca", "Skeletonema",
    "Dactyliosolen", "Thalassiosira", "Dinobryon", "Corethron",
    "Thalassionema", "Ditylum", "pennate", "Prorocentrum",
    "Pseudonitzschia", "Tintinnid", "Guinardia_striata", "Phaeocystis",
]

# The full list of 96 class names. In this notebook it is only referenced by a
# commented-out Preprocessor configuration. Order is preserved exactly.
all_classes = [
    "mix", "detritus", "Leptocylindrus", "mix_elongated",
    "Chaetoceros", "dino30", "Rhizosolenia", "Guinardia_delicatula",
    "Cerataulina", "Cylindrotheca", "Skeletonema", "Ciliate_mix",
    "Dactyliosolen", "Thalassiosira", "bad", "Dinobryon",
    "Corethron", "DactFragCerataul", "Thalassionema", "Ditylum",
    "pennate", "Prorocentrum", "Pseudonitzschia", "Mesodinium_sp",
    "G_delicatula_parasite", "Tintinnid", "Guinardia_striata", "Phaeocystis",
    "Dictyocha", "Pleurosigma", "Eucampia", "Thalassiosira_dirty",
    "Asterionellopsis", "flagellate_sp3", "Laboea_strobila", "Chaetoceros_didymus_flagellate",
    "Heterocapsa_triquetra", "Guinardia_flaccida", "Chaetoceros_pennate", "Ceratium",
    "Euglena", "Coscinodiscus", "Strombidium_morphotype1", "Paralia",
    "Gyrodinium", "Ephemera", "Pyramimonas_longicauda", "Proterythropsis_sp",
    "Gonyaulax", "kiteflagellates", "Chrysochromulina", "Chaetoceros_didymus",
    "bead", "Katodinium_or_Torodinium", "Leptocylindrus_mediterraneus", "spore",
    "Tontonia_gracillima", "Delphineis", "Dinophysis", "Strombidium_morphotype2",
    "Licmophora", "Lauderia", "clusterflagellate", "Strobilidium_morphotype1",
    "Leegaardiella_ovalis", "pennate_morphotype1", "amoeba", "Strombidium_inclinatum",
    "Pseudochattonella_farcimen", "Amphidinium_sp", "dino_large1", "Strombidium_wulffi",
    "Chaetoceros_flagellate", "Strombidium_oculatum", "Cerataulina_flagellate", "Emiliania_huxleyi",
    "Pleuronema_sp", "Strombidium_conicum", "Odontella", "Protoperidinium",
    "zooplankton", "Stephanopyxis", "Tontonia_appendiculariformis", "Strombidium_capitatum",
    "Bidulphia", "Euplotes_sp", "Parvicorbicula_socialis", "bubble",
    "Hemiaulus", "Didinium_sp", "pollen", "Tiarina_fusus",
    "Bacillaria", "Cochlodinium", "Akashiwo", "Karenia",
]

# The class subset handed to the Preprocessor via include_classes in this notebook.
classes_30 = [
    "Asterionellopsis", "bad", "Chaetoceros", "Chaetoceros_flagellate", "Ciliate_mix",
    "Corethron", "Cylindrotheca", "Dictyocha", "dino30", "detritus",
    "Dinobryon", "Ditylum", "Eucampia", "flagellate_sp3", "Guinardia_delicatula",
    "Guinardia_flaccida", "Guinardia_striata", "Heterocapsa_triquetra", "Laboea_strobila", "Leptocylindrus",
    "pennate", "Phaeocystis", "Pleurosigma", "Prorocentrum", "Pseudonitzschia",
    "Skeletonema", "Thalassionema", "Thalassiosira", "Thalassiosira_dirty", "Tintinnid",
]

# Sanity check: report how many classes were selected.
print(len(classes_30))


In [None]:
# Alternative configurations kept for reference (inactive):
#pp = Preprocessor(years, include_classes=classes, train_eg_per_class=HP.number_of_images_per_class)
#pp = Preprocessor(years, include_classes=all_classes, train_eg_per_class=HP.number_of_images_per_class, thresholding=HP.thresholding)

# Active configuration: restrict to classes_30 and take every tunable from
# the shared Hyperparameters object.
pp = Preprocessor(
    years,
    include_classes=classes_30,
    strategy=HP.pp_strategy,
    train_eg_per_class=HP.number_of_images_per_class,
    maxN=HP.maxN,
    minimum=HP.minimum,
    transformations=HP.transformations,
)

# Split fractions -- presumably train / validation / test, matching the
# loader names requested below (confirm against Preprocessor.create_datasets).
pp.create_datasets([0.6, 0.2, 0.2])

# One loader per split, requested in the order train, validation, test.
trainLoader, validLoader, testLoader = (
    pp.get_loaders(split_name, HP.batch_size)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
def load_sd(model, path_to_statedict):
    """Restore saved weights into ``model`` and switch it to evaluation mode.

    Args:
        model: Object exposing ``load_state_dict`` and ``eval`` (e.g. a
            ``torch.nn.Module``) whose parameters match the checkpoint.
        path_to_statedict: Location of the serialized state dict; passed
            straight to ``torch.load``.

    Returns:
        The same ``model`` object, now carrying the loaded weights.
    """
    def _leave_where_deserialized(storage, loc):
        # Handing the storage back unchanged ignores the saved device tag.
        return storage

    weights = torch.load(path_to_statedict, map_location=_leave_where_deserialized)
    model.load_state_dict(weights)
    model.eval()
    return model

# Instantiate the three ensemble members first ...
model_gn, model_vgg, model_resnet = GoogleNet(), VGG(), ResNet()

# ... then restore each one's saved weights, in the same order.
# load_sd returns the model it was given, switched to eval mode.
model_gn, model_vgg, model_resnet = (
    load_sd(network, checkpoint)
    for network, checkpoint in (
        (model_gn, "models/GoogleNet_1.2-4.3.pth"),
        (model_vgg, "models/VGG_2.0-4.3.pth"),
        (model_resnet, "models/ResNet_2.0-4.3.pth"),
    )
)

In [None]:
def test(model_gn, model_vgg, model_resnet ,testloader):
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    all_preds = torch.LongTensor().to(device)
    all_targets = torch.LongTensor().to(device)

    all_fnames = []
    model_gn.to(device)
    model_vgg.to(device)
    model_resnet.to(device)

    with torch.no_grad():

        for data in testloader:
            inputs, labels = data['image'].to(device).float(), data['encoded_label'].to(device).float()

            _, labels = torch.max(labels, 1)

            out_gn = model_gn(inputs)
            out_vgg = model_vgg(inputs)
            out_resnet = model_resnet(inputs)
            
            sums = torch.add(torch.add(out_gn, out_vgg), out_resnet)
            outputs = torch.div(sums, sums.shape[1])

            _, predicted = torch.max(outputs.data, 1)
            #print("labels:", labels)
            #print("predicted:", predicted)
            #print("~~~~~~~~~~~~~~~~")
            all_preds = torch.cat((all_preds, predicted), 0)
            all_targets = torch.cat((all_targets, labels), 0) 

            #if total >=10:
            #   break

            all_fnames.extend(data['fname'])

    return all_preds, all_targets, all_fnames

In [None]:
# Evaluate the ensemble on the test split; the returned
# (predictions, targets, file names) tuple is shown as the cell output.
test(
    model_gn=model_gn,
    model_vgg=model_vgg,
    model_resnet=model_resnet,
    testloader=testLoader,
)