In [57]:
import copy
import importlib

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import tqdm
from torch_geometric.data import Data, DataLoader

import models

In [58]:

# Dictionary mapping each behavior to its model spec: [encoder, classification head, readout].
# All behaviors currently share one architecture, so the spec comes from a single factory
# instead of nine copy-pasted literals. Each factory call builds *new* module instances,
# so no two behaviors share an encoder or classifier object.
def _default_model_spec(readout='mean'):
    ''' Build a fresh [encoder, classifier, readout] spec with the default GAT architecture.
    Parameters:
        - readout: str, pooling name passed on to models.GraphClassifier (default 'mean')
    Returns:
        - list: [models.GATEncoder, models.ClassificationHead, readout]
    '''
    encoder = models.GATEncoder(nout=64, nhid=32, attention_heads=2, n_in=4, n_layers=4, dropout=0.2)
    classifier = models.ClassificationHead(n_latent=64, nhid=32, nout=2)
    return [encoder, classifier, readout]


# Behavior names, in the order the models are created. 'Poursuit' is the spelling used
# by existing callers/checkpoints, so it is kept as-is.
_BEHAVIORS = ('General_Contact', 'Sniffing', 'Sniffing_head', 'Sniffing_other', 'Sniffing_anal',
              'Poursuit', 'Dominance', 'Rearing', 'Grooming')

MODELS = {behavior: _default_model_spec() for behavior in _BEHAVIORS}


### Function that returns the model based on the behavior
def get_model(behavior) -> nn.Module:
    ''' Returns a new, independent model for the given behavior.
        Possible behaviors: 'General_Contact', 'Sniffing', 'Sniffing_head', 'Sniffing_other', 'Sniffing_anal', 'Poursuit', 'Dominance', 'Rearing', 'Grooming'
    Parameters:
        - behavior: str, the behavior of the model (a key of MODELS)
    Returns:
        - model: nn.Module, a models.GraphClassifier built from deep copies of the
          encoder/classifier templates stored in MODELS[behavior]
    Raises:
        - KeyError: if behavior is not a key of MODELS
    '''
    if behavior not in MODELS:
        raise KeyError(f"Unknown behavior {behavior!r}; expected one of {list(MODELS)}")

    spec = MODELS[behavior]
    # MODELS holds a single encoder/classifier instance per behavior. Without copying,
    # every call would wrap the *same* modules, so loading weights into one returned
    # model (or moving it to another device) would silently alter all the others.
    gatencoder = copy.deepcopy(spec[0])
    classifier = copy.deepcopy(spec[1])
    readout = spec[2]

    return models.GraphClassifier(encoder=gatencoder, classifier=classifier, readout=readout)

In [59]:
# Re-execute models.py so edits to it are picked up without restarting the kernel.
# NOTE(review): MODELS was built in an earlier cell, before this reload, so it still holds
# instances of the pre-reload classes; re-run that cell after reloading if models.py changed.
importlib.reload(models)

<module 'models' from 'c:\\Users\\jalvarez\\Documents\\Code\\GitHubCOde\\Behavioral_Tagging_of_Mice_in_multiple_Mice_dataset_using_Deep_Learning\\src\\models.py'>

In [60]:
def load_model(model_path, device, behaviour='General_Contact'):
    ''' Build the architecture for `behaviour`, restore its trained weights and
    return it ready for inference.
    Args:
        model_path: path to a checkpoint file containing a 'model_state_dict' entry
        device: torch device used as map_location and as the model's final device
        behaviour: behaviour key selecting the architecture (see get_model)
    Returns:
        model: the restored model, moved to `device` and switched to eval mode
    '''
    net = get_model(behaviour)

    # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location=device)
    net.load_state_dict(ckpt['model_state_dict'])

    # Place on the target device, then disable dropout etc. for inference.
    net.to(device)
    net.eval()
    return net

In [80]:
# Checkpoint of the General_Contacts model to evaluate
model_path = r'd:\Backup_mantenimiento_ruche\Data\Checkpoints\new_encoder_no_linearResCon\General_Contacts\checkpoint_epoch_310.pth'

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [81]:
# Restore the trained General_Contact classifier onto the selected device
model = load_model(model_path, device, behaviour='General_Contact')

In [82]:
# Location of the serialized test dataset (read with torch.load and fed to the DataLoader)
data_path = r'c:\Users\jalvarez\Documents\Data\DataLoadaerTESTTSTST\dataset_test.pkl'

In [86]:
# Load the serialized test dataset. map_location='cpu' makes loading work regardless of the
# device the tensors were saved from (e.g. CUDA-saved data on a CPU-only machine); batches
# are moved to `device` individually during inference.
# NOTE: torch.load unpickles arbitrary objects — only load files you trust.
data = torch.load(data_path, map_location='cpu')

In [87]:
# Iterate over the test set one sample per batch, in stored order (no shuffling),
# so batch i lines up with row i of the predictions array.
test_loader = DataLoader(data, shuffle=False, batch_size=1)

In [85]:
# Sanity check: peek at the first batch and display the median of its frame_mask
a = next(iter(test_loader))
a.frame_mask.median().item()

2

In [88]:
# One row per batch: [frame index, predicted class]. Integer dtype so the CSV written
# later contains plain integers instead of floats such as "2.0".
output = np.zeros((len(test_loader), 2), dtype=int)
softmax = torch.nn.Softmax(dim=1)
with torch.no_grad():
    # Loop variable is `batch`, not `data`: reusing `data` overwrote the loaded dataset,
    # so re-running the DataLoader cell afterwards would have wrapped only the last batch.
    for i, batch in enumerate(tqdm.tqdm(test_loader)):
        batch = batch.to(device)
        # Median of frame_mask is used as the batch's frame index
        # (presumably every node in a batch shares one frame — TODO confirm).
        output[i, 0] = int(batch.frame_mask.median().item())
        out = softmax(model(batch))
        # argmax over the class axis; .item() raises if a batch ever holds more than one
        # graph, instead of silently returning an index into the flattened tensor.
        output[i, 1] = out.argmax(dim=1).item()

    

100%|██████████| 2627/2627 [00:29<00:00, 89.35it/s] 


In [89]:
def create_csv_with_output_behaviour(output, behaviour, path):
    ''' Save per-frame model predictions to a CSV file.
    Args:
        output: two-column array-like, one row per frame: [frame, prediction]
        behaviour: header used for the prediction column
        path: destination CSV file (written without the DataFrame index)
    '''
    pd.DataFrame(data=output, columns=['frame', behaviour]).to_csv(path_or_buf=path, index=False)

In [90]:
# Persist the predictions as a CSV next to the test data
path_to_save = r'c:\Users\jalvarez\Documents\Data\DataLoadaerTESTTSTST\output.csv'
create_csv_with_output_behaviour(output, behaviour='General_Contact', path=path_to_save)

In [92]:
# Ground-truth labels, read from CSV
path_to_ground_truth = r'c:\Users\jalvarez\Documents\Data\DataLoadaerTESTTSTST\GT\DMD_fem_Test_1.csv'
ground_truth = pd.read_csv(filepath_or_buffer=path_to_ground_truth)

# Get accuracy
def get_accuracy(ground_truth, output, column='General_Contacts'):
    ''' Return the fraction of rows of `output` whose prediction matches the ground truth.
    Args:
        ground_truth: DataFrame of labels; frame index k is looked up positionally
            (row k, as with .iloc), so rows are assumed to be ordered by frame from 0.
        output: array of shape (n, 2) holding [frame index, predicted class] per row
            (frame indices may be stored as floats; they are cast to int).
        column: ground-truth column to compare against (default 'General_Contacts').
    Returns:
        accuracy: float in [0, 1]; NaN when `output` is empty.
    '''
    if len(output) == 0:
        # Avoid a 0/0 division (and its RuntimeWarning); accuracy is undefined here.
        return float('nan')
    frames = output[:, 0].astype(int)
    labels = ground_truth[column].to_numpy()[frames]
    return float((labels == output[:, 1]).sum() / len(output))

# Fraction of test frames where the prediction matches the ground truth
get_accuracy(ground_truth=ground_truth, output=output)

0.677198325085649

In [2]:
# Load the model
def load_checkpoint(model, optimizer, path, device):
    ''' Restore model (and optionally optimizer) state from a training checkpoint.
    Args:
        model: nn.Module whose weights are replaced by checkpoint['model_state_dict']
        optimizer: optimizer restored from checkpoint['optimizer_state_dict'];
            pass None to skip it (e.g. when only running inference)
        path: path to a checkpoint written with torch.save
        device: map_location used when loading the checkpoint tensors
    Returns:
        (model, optimizer, epoch): the same model/optimizer objects that were passed in,
        plus the epoch stored in the checkpoint
    '''
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Checkpoint loaded from {path}, at epoch {epoch}")
    return model, optimizer, epoch

In [3]:
# Pick the compute device: GPU when PyTorch can see one, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Load the full dataset, mapping its tensors straight onto `device`.
# NOTE: torch.load unpickles arbitrary objects — only load files you trust.
dataset = torch.load(r'c:\Users\jalvarez\Documents\Data\LargeDataset\entire_dataset.pkl', map_location=device)



cpu
