In [1]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to the
# project's own modules (src.*) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import configparser
import os.path
from os import path
from importlib import reload
import wandb


# Candidate credential files, tried in order; the first one that exists wins.
creds_path_ar = ["../../credentials.ini","credentials.colab.ini"]

# Defaults for every setting read by later cells. Without them ENV and
# WANDB_enable were only bound inside the loop below, so a machine with no
# credentials file hit a NameError on the `ENV == "COLAB"` check.
PATH_ROOT = ""
PATH_DATA = ""
WANDB_enable = False
ENV = ""

for creds_path in creds_path_ar:
    if path.exists(creds_path):
        config_parser = configparser.ConfigParser()
        config_parser.read(creds_path)
        PATH_ROOT = config_parser['MAIN']["PATH_ROOT"]
        PATH_DATA = config_parser['MAIN']["PATH_DATA"]
        # WANDB_ENABLE is stored as text; only the exact string 'TRUE' enables it.
        WANDB_enable = config_parser['MAIN']["WANDB_ENABLE"] == 'TRUE'
        ENV = config_parser['MAIN']["ENV"]
        break
else:
    # Loop finished without `break`: nothing was found, defaults stay in effect.
    print("WARNING: no credentials file found in", creds_path_ar,
          "- using empty paths and disabling wandb")

if ENV == "COLAB":
    # Only importable inside Google Colab, hence the conditional import.
    from google.colab import drive
    mount_path = '/content/gdrive/'
    drive.mount(mount_path)

In [3]:

    # Set the notebook name before wandb.init so the run can read it, and only
    # start a run when the credentials file enabled wandb.
    os.environ['WANDB_NOTEBOOK_NAME'] = '[SS]Alexnet_pytorch'
    if WANDB_enable:
        wandb.init(project="sota-mafat-base")

In [4]:
cd {PATH_ROOT}

/home/ubuntu/sota-mafat-radar


In [5]:
import os
import random
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim


from sklearn.metrics import roc_auc_score, roc_curve, auc, accuracy_score
from matplotlib.colors import LinearSegmentedColormap
from termcolor import colored

from src.data import feat_data, get_data, get_data_pipeline
from src.models import arch_setup, base_base_model, alex_model, base_3d
from src.features import specto_feat

# Seed every random number generator in use so results are reproducible.
seed_value = 0
# NOTE(review): PYTHONHASHSEED set at runtime presumably only affects child
# processes, not the hash randomisation of this running kernel - confirm.
os.environ['PYTHONHASHSEED'] = str(seed_value)

random.seed(seed_value)
np.random.seed(seed_value)
# Use the shared seed instead of a separate literal 0, so changing seed_value
# reseeds torch together with `random` and numpy.
torch.manual_seed(seed_value)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')

In [35]:
# Options handed to get_data_pipeline.pipeline_trainval.
config = {
    'num_tracks': 3,
    'val_ratio': 3,
    # Ten shift offsets spread evenly over [1, 31], floored to integers.
    'shift_segment': list(np.floor(np.linspace(1, 31, 10)).astype(int)),
    'get_shifts': True,
    'get_horizontal_flip': False,
    'get_vertical_flip': False,
    'wavelets': True,
}

# Training hyper-parameters.
batch_size, lr = 32, 1e-4

In [36]:
# Force a re-import of the data pipeline module so the next cell runs its
# latest on-disk version.
# NOTE(review): `%autoreload 2` is enabled at the top of the notebook, so this
# manual reload is presumably redundant - confirm before removing. The
# `import importlib` must stay: a later cell calls importlib.reload(arch_setup).
import importlib
importlib.reload(get_data_pipeline)

<module 'src.data.get_data_pipeline' from '/home/ubuntu/sota-mafat-radar/src/data/get_data_pipeline.py'>

In [37]:
# Build the train/validation inputs and labels from PATH_DATA using the options
# in `config`. The progress bars recorded with this cell show the run taking
# roughly three hours, so avoid re-executing it unnecessarily.
train_x, train_y, val_x, val_y = get_data_pipeline.pipeline_trainval(PATH_DATA, config)

100%|██████████| 51761/51761 [2:56:23<00:00,  4.89it/s]
100%|██████████| 616/616 [02:03<00:00,  4.97it/s]


In [12]:
# Report the number of training and validation samples, one per line.
print(train_x.shape[0], val_x.shape[0], sep="\n")

6261
616


In [38]:
# Wrap the arrays in the project's Dataset class.
train_set = arch_setup.DS(train_x, train_y)
val_set = arch_setup.DS(val_x, val_y)

# Training batches are drawn in random order.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2)
# Validation batches are served in dataset order. Shuffling gives no benefit at
# evaluation time, and `val_y` is later passed to train_epochs separately, so
# predictions need to stay aligned with it.
# NOTE(review): confirm how arch_setup.train_epochs pairs predictions with val_y.
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False, num_workers=2)

In [39]:
# Model, loss and optimizer for the AlexNet-based classifier.
model = alex_model.alex_mdf_model()
criterion = nn.BCELoss()
# Take the learning rate from the config cell so the value logged to wandb
# below is the one actually used (it was a separate hard-coded 0.0001).
optimizer = optim.Adam(model.parameters(), lr=lr)

model.to(device)

if WANDB_enable:
    runname = input("Enter WANDB runname(ENTER to skip wandb) :")
    if runname.strip() == "":
        # Honour the prompt: an empty run name turns wandb off for this
        # session, including the later train_epochs call that reads the flag.
        WANDB_enable = False
    else:
        notes = input("Enter run notes :")

        # Set the notebook name before wandb.init so the run can read it.
        os.environ['WANDB_NOTEBOOK_NAME'] = '[SS]Alexnet_pytorch'
        wandb.init(project="sota-mafat-base",name=runname, notes=notes)

        wandb.watch(model)
        # Record the data options and run hyper-parameters with the run.
        wandb.config['data_config'] = config
        wandb.config['train_size'] = train_x.shape[0]
        wandb.config['val_size'] = val_x.shape[0]
        wandb.config['batch_size'] = batch_size
        wandb.config['learning rate'] = lr
        # NOTE(review): this also logs the config dict as a metrics row; it
        # looks redundant with wandb.config above - confirm it is intended.
        wandb.log(config)


KeyboardInterrupt: Interrupted by user

In [None]:
# Swap the first convolution for one that accepts 8 input channels.
# NOTE(review): 8 presumably matches the channel count produced by the data
# pipeline with the current `config` - confirm.
model.arch.features[0] = nn.Conv2d(in_channels=8, out_channels=64, kernel_size=11, stride=4, padding=2)
model = model.to(device)

# The optimizer was built before this layer existed, so it holds the old
# layer's parameters and would never update the new convolution. Rebuild it
# over the model's current parameters.
optimizer = optim.Adam(model.parameters(), lr=lr)

# Keep the model summary as the cell's displayed output.
model

In [None]:
# Train for 10 epochs and keep the returned log for the plotting cells below.
log = arch_setup.train_epochs(
    train_loader,
    val_loader,
    model,
    criterion,
    optimizer,
    num_epochs=10,
    device=device,
    train_y=train_y,
    val_y=val_y,
    WANDB_enable=WANDB_enable,
    wandb=wandb,
)

In [None]:
# Visualise the training run from the `log` returned by train_epochs.
# NOTE(review): presumably plots train vs. test loss per epoch - confirm in arch_setup.
arch_setup.plot_loss_train_test(log,model)

In [None]:
# Evaluate the trained model on both loaders using the selected device.
# NOTE(review): presumably draws ROC curves for the train and validation sets - confirm in arch_setup.
arch_setup.plot_ROC_local_gpu(train_loader,val_loader,model,device)

In [31]:
# Force a re-import of src.models.arch_setup after editing it on disk.
# NOTE(review): relies on `import importlib` from an earlier cell, and is
# presumably redundant with `%autoreload 2` - confirm before removing.
importlib.reload(arch_setup)

<module 'src.models.arch_setup' from '/home/ubuntu/sota-mafat-radar/src/models/arch_setup.py'>