<a href="https://colab.research.google.com/github/stepanbabayan/DFBS-Object-Classification/blob/colab/test_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Clone Repository

In [None]:
# Clone the project repository into ./DFBS-Object-Classification/ (git's default target directory).
!git clone https://github.com/stepanbabayan/DFBS-Object-Classification.git

## Rename the clone and switch to the `colab` branch

In [None]:
mv ./DFBS-Object-Classification/ ./Model/

In [None]:
cd Model

In [None]:
# Check out the `colab` branch, the branch this notebook is published from.
!git checkout colab

In [None]:
# Optional: uncomment to pull the latest commits of the currently checked-out branch.
# !git pull

## Additional Environment Setup

In [None]:
import zipfile

# Unpack ./data.zip (found in the current working directory) next to it.
# ZipFile opens in read mode by default, and '.' names the same target as ''.
with zipfile.ZipFile('./data.zip') as archive:
    archive.extractall('.')

In [None]:
import sys
sys.path.append('Model/')

## Imports

In [None]:
import os
import shutil

import torch.optim
from torchsummary import summary

import load_data
import models
from Model.test import evaluate
from _helpers import make_directory

from sklearn.metrics import classification_report

## Device configuration

In [None]:
use_gpu: bool = True  # request CUDA for evaluation; set to False to force the CPU

In [None]:
# Evaluation device: the first CUDA GPU when requested and available, otherwise the CPU.
if use_gpu and torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    if use_gpu:
        # GPU was requested but no CUDA device is visible to torch.
        print('CUDA is not available; falling back to CPU.')
    device = torch.device('cpu')

# Report the device that was actually selected, not the one that was requested.
print('Device: GPU' if device.type == 'cuda' else 'Device: CPU')

## Dataset selection and paths

In [None]:
# Choose the dataset
num_classes = 5 # Choices: {5, 6, 10}

In [None]:
# Label sets of the supported dataset variants; the 5- and 6-class sets are subsets of the 10-class set.
classes_5 = ['C-H', 'C-N', 'Mrk SB', 'sdA', 'sdB']
classes_6 = ['C-H', 'C-N', 'Mrk Abs', 'Mrk SB', 'sdA', 'sdB']
classes_10 = ['C Ba', 'C-H', 'C-N', 'C-R', 'Mrk Abs', 'Mrk AGN', 'Mrk SB', 'sdA', 'sdB', 'sdO']

assert num_classes in {5, 6, 10}

# Variant -> (class labels, checkpoint file relative to the Checkpoint folder).
_dataset_variants = {
    5: (classes_5, 'Dense_5_High_Focal_25_3_Final/59.pth'),
    6: (classes_6, 'Dense_6_High_Focal_25_3_Final/136.pth'),
    10: (classes_10, 'Dense_10_Focal_25_3_Final/139.pth'),
}
classes, checkpoint_name = _dataset_variants[num_classes]


In [None]:
# Datasets. ./data is expected to hold one folder per class under train/ and test/
# (the layout implied by the copy paths below). For the 5- and 6-class variants, only the
# selected class folders are copied into ./data_<num_classes>/, and that folder becomes the data root.
data_root = './data'

if num_classes != 10:
    new_data_root = f'./data_{num_classes}'
    domains = ['test', 'train']

    for domain in domains:
        make_directory(f'{new_data_root}/{domain}')

        for class_name in classes:
            # dirs_exist_ok makes the copy repeatable. Without it, re-running this cell, or the whole
            # notebook after a kernel restart, raises FileExistsError once the folders exist.
            shutil.copytree(
                f'{data_root}/{domain}/{class_name}/',
                f'{new_data_root}/{domain}/{class_name}/',
                dirs_exist_ok=True,
            )

    data_root = new_data_root

In [None]:
# Split directories inside the selected data root; the training split is optional.
test_dir, train_dir = (f'{data_root}/{split}' for split in ('test', 'train'))

# Network input shape, passed to models.Model when the network is built.
input_shape = (160, 50)

In [None]:
print(f'Num classes: {num_classes}')

## Project Parameters

In [None]:
# Absolute path of the current working directory.
root_dir = os.getcwd()

In [None]:
# Trained weights are stored under <root_dir>/Checkpoint/.
checkpoint_path = '/'.join((root_dir, 'Checkpoint', checkpoint_name))

## Testing Parameters

In [None]:
# Mini-batch sizes for the test loader and the optional train loader.
test_batch_size = train_batch_size = 16

## Data Loaders

In [None]:
# Test loader. The second returned value (presumably the class names) is forwarded to `evaluate` later.
test_data, test_classes, _ = load_data.load_images(
    test_dir, test_batch_size, 'test', _drop_last=False,
)
# Optional: the train loader is only used to report a training-set score.
train_data, _, _ = load_data.load_images(
    train_dir, train_batch_size, 'train', _drop_last=False,
)

## Model Setup

In [None]:
# Supported values for `arch`:
#   'default'      - the proposed network
#   'default_bn'   - the proposed network with additional BatchNorm layers
#   'default_prev' - the network proposed in the previous work
#   'mobilenet'    - MobileNetV2
#   'resnet'       - ResNet
net = models.Model(num_classes=num_classes, input_shape=input_shape, arch='default')
net = net.to(device)  # Module.to moves parameters in place and returns the same module

### Layers

In [None]:
# Display the layer-by-layer structure; as the last expression, Jupyter renders it as the cell output.
net

### Output Summary

In [None]:
# Print a per-layer summary for a single-channel input of shape `input_shape`, reusing that value so
# the summary stays in sync with the model. torchsummary builds its dummy input on CUDA by default
# whenever CUDA is available, so pass the model's device type; otherwise this breaks when use_gpu is False.
summary(net, (1, *input_shape), device=device.type)

In [None]:
# Load the trained weights and switch the network to evaluation mode (affects layers such as dropout/BatchNorm).
# map_location remaps stored tensors onto the selected device, so a checkpoint saved on a GPU
# also loads on a CPU-only runtime.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.eval();  # trailing ';' suppresses re-displaying the whole architecture

In [None]:
print('\nEvaluation started:')

# Score the optional training split, then the test split (with the test loader's class list).
train_score = evaluate(
    dataloader=train_data,
    model=net,
    device=device,
    domain='train',
)
test_score = evaluate(
    dataloader=test_data,
    model=net,
    device=device,
    classes=test_classes,
)