In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob

import torch
from torch.utils.data import DataLoader, random_split, Subset
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from dataset import spikeData
from preprocess import preprocess, thresholdEvents
from models import AE,VAE
from clustering import *

import argparse
import pdb
import random

seed = 42
# Seed every RNG the notebook draws from -- PyTorch, NumPy's legacy global
# generator and Python's `random` -- so runs are repeatable.
# NOTE(review): some CUDA kernels can still be nondeterministic.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)


def train(model, max_epoch=50, loss_plot_path='loss_curve.pdf'):
    """Train an autoencoder whose forward pass returns ``(xHat, z)``.

    Each epoch makes one optimisation pass over the notebook-level
    ``train_loader`` with the notebook-level ``optimizer``. It then reports the
    mean per-batch ``model.loss_function(xHat, xIp, z)`` on ``valid_loader``.
    After training, the loss curves are saved to ``loss_plot_path``.

    Args:
        model: network with ``forward(x) -> (xHat, z)`` and a
            ``loss_function(xHat, xIp, z)`` method.
        max_epoch: number of passes over the training set.
        loss_plot_path: file that the training/validation loss plot is saved to.

    Returns:
        The same ``model`` object, trained in place. Its train/eval mode is
        restored to what it was on entry.

    Note:
        Relies on the globals ``train_loader``, ``valid_loader``,
        ``optimizer`` and ``device`` defined in the main cell.
    """
    def _mean(total, n_batches):
        # Mean per-batch loss; NaN rather than ZeroDivisionError for an empty loader.
        return total / n_batches if n_batches else float('nan')

    was_training = model.training
    trLoss = []
    vlLoss = []
    for epoch in range(max_epoch):
        # Training pass. Set training mode explicitly because the validation
        # pass below switches the model to eval mode.
        model.train()
        epLoss = 0.0
        for x, xIp, _ in tqdm(train_loader):
            x, xIp = x.to(device), xIp.to(device)
            xHat, z = model(x)
            loss = model.loss_function(xHat, xIp, z)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epLoss += loss.item()
        trLoss.append(_mean(epLoss, len(train_loader)))

        # Validation pass: nothing is back-propagated here, so skip building
        # the autograd graph (saves memory and time) and use eval mode.
        model.eval()
        epLoss = 0.0
        with torch.no_grad():
            for x, xIp, _ in valid_loader:
                x, xIp = x.to(device), xIp.to(device)
                xHat, z = model(x)
                epLoss += model.loss_function(xHat, xIp, z).item()
        vlLoss.append(_mean(epLoss, len(valid_loader)))

        print('Epoch: %03d, Tr. Loss: %.4f, Vl.Loss: %.4f'
              % (epoch, trLoss[-1], vlLoss[-1]))

    model.train(was_training)

    fig, ax = plt.subplots()
    ax.plot(trLoss, label='training')
    ax.plot(vlLoss, label='validation')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.legend()
    fig.tight_layout()
    fig.savefig(loss_plot_path, dpi=300)
    return model

def train_VAE(model, max_epoch=50, loss_plot_path='loss_curve_VAEl1loss0529.pdf'):
    """Train a VAE whose forward pass returns ``(xHat, mean, log_var)``.

    Each epoch makes one optimisation pass over the notebook-level
    ``train_loader`` with the notebook-level ``optimizer``. The learning rate
    follows a cosine-annealing-with-warm-restarts schedule (T_0=10, T_mult=2,
    eta_min=1e-5) that is stepped after every batch. The epoch then reports the
    mean per-batch ``model.loss_function(xHat, xIp, mean, log_var)`` on
    ``valid_loader``. After training, the loss curves are saved to
    ``loss_plot_path``.

    Args:
        model: network with ``forward(x) -> (xHat, mean, log_var)`` and a
            ``loss_function(xHat, xIp, mean, log_var)`` method.
        max_epoch: number of passes over the training set.
        loss_plot_path: file that the training/validation loss plot is saved to.

    Returns:
        The same ``model`` object, trained in place. Its train/eval mode is
        restored to what it was on entry.

    Note:
        Relies on the globals ``train_loader``, ``valid_loader``,
        ``optimizer`` and ``device`` defined in the main cell. A new scheduler
        is built on every call, so the LR schedule starts over each time.
    """
    def _mean(total, n_batches):
        # Mean per-batch loss; NaN rather than ZeroDivisionError for an empty loader.
        return total / n_batches if n_batches else float('nan')

    was_training = model.training
    trLoss = []
    vlLoss = []

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-5, last_epoch=-1)
    iters = len(train_loader)
    for epoch in range(max_epoch):
        # Training pass. Set training mode explicitly because the validation
        # pass below switches the model to eval mode.
        model.train()
        epLoss = 0.0
        for i, (x, xIp, _) in enumerate(tqdm(train_loader)):
            x, xIp = x.to(device), xIp.to(device)
            xHat, mean, log_var = model(x)
            loss = model.loss_function(xHat, xIp, mean, log_var)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # A fractional epoch value updates the cosine/warm-restart LR every batch.
            scheduler.step(epoch + i / iters)
            epLoss += loss.item()
        trLoss.append(_mean(epLoss, iters))

        # Validation pass: nothing is back-propagated here, so skip building
        # the autograd graph and use eval mode.
        # NOTE(review): if the VAE's forward samples differently in eval mode,
        # validation loss is now computed in that mode -- confirm intended.
        model.eval()
        epLoss = 0.0
        with torch.no_grad():
            for x, xIp, _ in valid_loader:
                x, xIp = x.to(device), xIp.to(device)
                xHat, mean, log_var = model(x)
                epLoss += model.loss_function(xHat, xIp, mean, log_var).item()
        vlLoss.append(_mean(epLoss, len(valid_loader)))

        print('Epoch: %03d, Tr. Loss: %.6f, Vl.Loss: %.6f'
              % (epoch, trLoss[-1], vlLoss[-1]))

    model.train(was_training)

    fig, ax = plt.subplots()
    ax.plot(trLoss, label='training')
    ax.plot(vlLoss, label='validation')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.legend()
    fig.tight_layout()
    fig.savefig(loss_plot_path, dpi=300)
    return model


### Main starts here
# Run on the CUDA device when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)


  from pandas.core.computation.check import NUMEXPR_INSTALLED


cuda


In [2]:
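# Earlier default settings, kept for reference only (all commented out);
# the values actually used are set in the Args class below.
# data = 'foundation_data_all.pt'  # Spike data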
# output = None  # Output filename
# thresh = 0.7  # Event detection threshold
# cutoff = 0.5  # Base event detection threshold
# poff = 12  # Positive offset after peak events
# noff = 8  # Negative offset after peak events
# corr = 0.0  # Correlation threshold for merging templates
# mask = False  # Pretext task of masking input
# kmeans = False  # Use kmeans for final clustering
# jigsaw = False  # Pretext task of sorting jigsawed input
# shuffle = False  # Pretext task of sorting shuffled input

# epochs = 20  # Number of epochs
# hidden = 2048  # Number of hidden units at input in AE
# latent = 2  # Number of latent dimensions at bottleneck
# batch = 1024  # Training batch size
# lr = 0.0001  # Learning rate

# ip_dim = 20  # Number of input features
# save = True  # Save the trained model
# denoise = False  # Denoise the data
# reprocess = False  # Reprocess data
# retrain = False  # Retrain model every threshold
# model = None  # Pretrained model file
class Args:
    """Run configuration, standing in for an ``argparse`` namespace.

    Every option gets the default below. Pass keywords to override any of
    them, e.g. ``Args(epochs=5, latent=2)``. An unknown keyword raises
    ``TypeError``, so a misspelt option cannot silently keep its default.
    """

    def __init__(self, **overrides):
        # --- data / preprocessing ---
        # Spike data file, loaded with torch.load in the main cell.
        # NOTE(review): absolute local path; override with Args(data=...) on other machines.
        self.data = r"E:\Work\SpikeSorting\kiehnlab\spikedate_pytorch\foundation_data_all.pt"
        self.output = None      # Output filename
        self.thresh = 0.7       # Event detection threshold
        self.cutoff = 0.5       # Base event detection threshold
        self.poff = 12          # Positive offset relative to peak events
        self.noff = 8           # Negative offset relative to peak events
        self.corr = 0.0         # Correlation threshold for merging templates
        # --- pretext tasks / clustering ---
        self.mask = False       # Pretext task of masking input
        self.kmeans = False     # Use kmeans for final clustering
        self.jigsaw = False     # Pretext task of sorting jigsawed input
        self.shuffle = True     # Pretext task of sorting shuffled input
        # --- model / optimisation ---
        self.epochs = 50        # Number of training epochs
        self.hidden = 2048      # Number of hidden units at input in AE
        self.latent = 8         # Number of latent dimensions at bottleneck
        self.batch = 4096       # Training batch size
        self.lr = 0.0002        # Learning rate
        self.ip_dim = 20        # Number of input features
        # --- run control ---
        self.save = True        # Save the trained model
        self.denoise = False    # Denoise the data
        self.reprocess = False  # Reprocess data
        self.retrain = False    # Retrain model every threshold
        self.model = None       # Pretrained state-dict file to load (None: fresh weights)

        unknown = set(overrides).difference(vars(self))
        if unknown:
            raise TypeError('Args() got unexpected option(s): ' + ', '.join(sorted(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)

args = Args()
# Show only the file name. Normalise Windows back-slashes first so both
# "E:\dir\file.pt" and "/dir/file.pt" style paths reduce to "file.pt".
data_filename = args.data.replace('\\', '/').split('/')[-1]
print(f'Loading and preprocessing {data_filename}')

Loading and preprocessing E:\Work\SpikeSorting\kiehnlab\spikedate_pytorch\foundation_data_all.pt


In [3]:
### Name under which the trained weights are saved ('.pt' is appended at save time).
model_name = args.model
if model_name is None:
    model_name = 'models/model_H_' + repr(args.hidden) + 'L_' + repr(args.latent) + '_VAEMSE_0524'
# Tag the name with the active pretext task (only the first enabled one, in this order).
if args.shuffle:
    model_name += '_shuffle'
elif args.mask:
    model_name += '_mask'
elif args.jigsaw:
    model_name += '_jigsaw'

### Obtain a distribution of number of events for different thresholds
num_events = []  # NOTE(review): not filled in this cell
thresh = args.thresh

### Instantiate a model!
nIp = args.ip_dim
model = VAE(nIp=nIp, nhid=args.hidden, latent_dim=args.latent)
model = model.to(device)
# NOTE(review): `criterion` is not used in this cell; train_VAE uses model.loss_function.
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

### Optionally start from a pretrained checkpoint (skipped when retraining).
if args.model is not None and not args.retrain:
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.model, map_location=device))
    print("Using pretrained model...")
    # NOTE(review): drops the last 9 characters of the checkpoint path --
    # TODO confirm which suffix this is meant to strip.
    model_name = args.model[:-9]

# Printed only after any pretrained-model override, so the reported name is final.
print('Saving model as ' + model_name)

### Load the events and keep only those whose maximum over the last axis is <= 1.0.
# NOTE(review): torch.load unpickles the file -- only load trusted data.
events = torch.load(args.data).squeeze()
# Two singleton axes are inserted at dim 1; assumes the squeezed tensor is
# (n_events, n_samples) -- TODO confirm the layout spikeData expects.
events = events[events.max(-1).values <= 1.0].unsqueeze(1).unsqueeze(1)
posEventIdx = thresholdEvents(events, thresh)
N = len(events[posEventIdx])
print("Found %d events at %.4f threshold" % (N, thresh))

tmp = 0  # NOTE(review): unused in this cell

print("################## Using Threshold=%.2f ##############" % thresh)
### Make torch dataset
evClsLabel = np.zeros(len(events), dtype=int)
eventSeq = np.arange(len(events), dtype=int)
dataset = spikeData(data=events, evMask=evClsLabel, mask=args.mask, event_index=eventSeq,
                    jigsaw=args.jigsaw, shuffle=args.shuffle, thresh=thresh)

### Contiguous 80/20 train/validation split (first 80% of samples train).
# NOTE(review): the split indexes `dataset` with 0..N-1, where N counts the
# events passing thresholdEvents; this assumes spikeData applies the same
# threshold so that len(dataset) == N -- TODO confirm.
nTrain = int(0.8 * N)
nValid = N - nTrain
# An empty split would otherwise only surface later as a division by zero in training.
assert nTrain > 0 and nValid > 0, 'need at least one training and one validation event (N=%d)' % N
train_set = Subset(dataset, range(nTrain))
valid_set = Subset(dataset, range(nTrain, N))
B = args.batch

### Wrapping the datasets with DataLoader class
train_loader = DataLoader(train_set, batch_size=B, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=B, shuffle=False)
print("Ntrain: %d, NValid: %d" % (nTrain, nValid))

### Train the neural network! (the model was already moved to `device` above)
print("Training on every datafile...")
model = train_VAE(model, args.epochs)

### Save the trained weights (state_dict only) unless saving is disabled.
mName = model_name + '.pt'
if args.save:
    # Create the target directory (e.g. "models/") first, so a missing folder
    # does not make torch.save fail after the whole training run.
    os.makedirs(os.path.dirname(mName) or '.', exist_ok=True)
    print('Saving model ' + mName)
    torch.save(model.state_dict(), mName)

Saving model as models/model_H_2048L_8_VAEMSE_0524_shuffle
Found 4826366 events at 0.7000 threshold
################## Using Threshold=0.70 ##############
Using shuffle pretext
Found 4826366 events at 0.70 threshold
Ntrain: 3861092, NValid: 965274
Training on every datafile...


100%|██████████| 943/943 [02:32<00:00,  6.17it/s]


Epoch: 000, Tr. Loss: 0.080343, Vl.Loss: 0.049042


100%|██████████| 943/943 [02:31<00:00,  6.20it/s]


Epoch: 001, Tr. Loss: 0.051211, Vl.Loss: 0.049063


100%|██████████| 943/943 [02:33<00:00,  6.13it/s]


Epoch: 002, Tr. Loss: 0.050866, Vl.Loss: 0.048989


100%|██████████| 943/943 [03:45<00:00,  4.19it/s]


Epoch: 003, Tr. Loss: 0.050752, Vl.Loss: 0.049004


100%|██████████| 943/943 [02:18<00:00,  6.81it/s]


Epoch: 004, Tr. Loss: 0.050694, Vl.Loss: 0.048982


100%|██████████| 943/943 [02:26<00:00,  6.42it/s]


Epoch: 005, Tr. Loss: 0.050564, Vl.Loss: 0.048965


100%|██████████| 943/943 [02:22<00:00,  6.64it/s]


Epoch: 006, Tr. Loss: 0.050516, Vl.Loss: 0.048906


100%|██████████| 943/943 [03:39<00:00,  4.30it/s]


Epoch: 007, Tr. Loss: 0.050465, Vl.Loss: 0.048907


 21%|██        | 196/943 [00:34<02:11,  5.67it/s]


KeyboardInterrupt: 