# CORnet-Z

In [None]:
import zipfile

import torch
from PIL import Image
from torch.utils.data import Dataset

class ZipDataset(Dataset):
    """Image-classification dataset read directly from a ``.zip`` archive.

    Every ``.png`` / ``.jpg`` / ``.jpeg`` member (extension matched
    case-insensitively) becomes one sample. Its label is the second path
    component parsed as an integer, so the archive is expected to be laid out
    as ``<root>/<class_index>/<image>`` (e.g. ``train/3/img_001.png``).
    macOS metadata entries (``__MACOSX/...`` and ``._*`` files) are skipped.

    Args:
        zip_path: path to the zip archive.
        transform: optional callable applied to the RGB ``PIL.Image``.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, zip_path, transform=None):
        self.zip_path = zip_path
        self.transform = transform
        self.samples = []  # list of (archive member name, integer class index)
        with zipfile.ZipFile(self.zip_path, 'r') as archive:
            for name in archive.namelist():
                if not name.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                parts = name.split('/')
                # Resource-fork files added by macOS archivers look like images
                # but are not, and their paths do not follow the class layout.
                if parts[0] == '__MACOSX' or parts[-1].startswith('._'):
                    continue
                class_name = parts[1]  # assumes <root>/<class_index>/<image> layout
                self.samples.append((name, int(class_name)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``label`` is a ``torch.long`` scalar."""
        file_path, class_index = self.samples[idx]
        # The archive is reopened for every item instead of keeping one handle on
        # the instance, so DataLoader worker processes never share an open file.
        with zipfile.ZipFile(self.zip_path, 'r') as archive:
            with archive.open(file_path) as member:
                # convert() decodes the image now, while the member is still open.
                image = Image.open(member).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(class_index, dtype=torch.long)

## Train Model

In [None]:
import os
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from functools import reduce
from cornet.extend_cornet_z import ExtendedCORnet

seed = 0
pl.seed_everything(seed)

# -------------------------------
# Parameters
# -------------------------------

# ESOS: 5
# Testolin: 10
n_max = 5  # max number of dots; also the number of output units of the model

# -------------------------------
# Paths where to load/save data
# -------------------------------

# path where models are saved
model_path = project_path + '/result/model/cornetz_sgd'
os.makedirs(model_path, exist_ok=True)
# path where training logs are saved
log_path = project_path + '/result/log/train'
os.makedirs(log_path, exist_ok=True)
# path containing the dataset (one sub-folder per class, as ImageFolder expects)
train_path = project_path + '/images/ESOS/train'
#train_path = project_path + '/images/Testolin_DeWind/train'
#train_path = project_path + '/images/Testolin_Natural/train'
#train_path = project_path + '/images/Testolin_ISA2/train'

# -------------------------------
# Training dataset
# -------------------------------

data_transform = transforms.Compose([
    # Resize the images to 300x300
    transforms.Resize(size=(300, 300)),
    # Flip the images horizontally with probability 0.5 (data augmentation)
    transforms.RandomHorizontalFlip(p=0.5),
    # Convert to a float tensor; also rescales pixel values from [0, 255] to [0.0, 1.0]
    transforms.ToTensor(),
])

#train_data = ZipDataset(train_path, transform=data_transform)
# NOTE: ImageFolder assigns class indices in lexicographic order of the folder
# names. If the folders are named by numerosity (e.g. "1".."10"), index k is
# NOT numerosity k+1 ("10" sorts before "2") -- keep this in mind when
# interpreting predictions.
train_data = datasets.ImageFolder(
    root=train_path,            # target folder of images
    transform=data_transform,   # transforms applied to the images
    target_transform=None,      # transforms applied to the labels (none)
)
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

# -------------------------------
# Initializing model
# -------------------------------

model = ExtendedCORnet(out_features=n_max, lr=1e-3, optimizer='sgd')

# -------------------------------
# Saving model
# -------------------------------

# save the untrained model as epoch -1 (f'epoch{-1:02}' == 'epoch-1', so the
# test script's epoch{epoch:02}.ckpt naming also finds it)
torch.save({
    "epoch": -1,
    "global_step": 0,
    "pytorch-lightning_version": pl.__version__,
    "state_dict": model.state_dict(),
}, f'{model_path}/epoch-1.ckpt')
# save a checkpoint at the end of every training epoch as epoch00.ckpt,
# epoch01.ckpt, ...; save_top_k=-1 keeps all of them
checkpoint = pl.callbacks.ModelCheckpoint(dirpath=model_path, filename='epoch{epoch:02d}', auto_insert_metric_name=False, save_on_train_epoch_end=True, save_top_k=-1)
# log device (GPU) statistics during training
gpu_stats = pl.callbacks.DeviceStatsMonitor()

# -------------------------------
# Training model
# -------------------------------

trainer = pl.Trainer(default_root_dir=log_path, callbacks=[gpu_stats, checkpoint], deterministic=True, accelerator='gpu', devices=1, strategy='auto', num_nodes=1, max_epochs=100)
# multi-GPU alternative (DistributedDataParallel):
#trainer = pl.Trainer(default_root_dir=log_path, callbacks=[gpu_stats, checkpoint], deterministic=True, accelerator='gpu', devices=4, strategy='ddp', num_nodes=1, max_epochs=100)
trainer.fit(model, train_loader)

## Test Model

In [None]:
import os
import os, psutil
process = psutil.Process(os.getpid())

import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from humanfriendly import format_size
from functools import reduce
from cornet.extend_cornet_z import ExtendedCORnet

def _rss():
    """Return this process's resident set size, human-readable (for memory logging)."""
    return format_size(process.memory_info().rss)


def _save_npz_once(path, name, epoch, **arrays):
    """Save ``arrays`` to ``path`` with ``np.savez_compressed`` unless the file exists.

    ``name`` and ``epoch`` are only used in the progress message.
    """
    if os.path.exists(path):
        print(f'{name} at epoch {epoch} already saved in .npz ({_rss()})')
        return
    np.savez_compressed(path, **arrays)
    print(f'{name} at epoch {epoch} saved in .npz ({_rss()})')


def main(args):
    """Evaluate saved checkpoints; save per-epoch labels, layer activities and accuracies.

    For every epoch in ``args.epochs`` the checkpoint ``epoch{epoch:02}.ckpt``
    is loaded from ``model_path`` and evaluated with ``trainer.test``. Then:

    * labels (and condition/param, when the dataset's batches contain them)
      are written to ``<activity_path>/epoch{epoch:02}/<field>.npz``;
    * the outputs of the V1, V2, V4, IT and decoder blocks, captured by
      forward hooks, are written there as ``<block>.npz``;
    * the test accuracy (``metrics['test_acc_epoch']``) is collected.

    Existing ``.npz`` files are not overwritten. Finally, all accuracies are
    saved as one structured array to ``<accuracy_path>/epochs_{epochs}_accuracy.npy``.

    Args:
        args: parsed CLI namespace; ``args.epochs`` is the list of epochs to test.
    """
    seed = 0
    pl.seed_everything(seed)

    # -------------------------------
    # Parameters
    # -------------------------------

    n_max = 5  # number of output units of the model

    # -------------------------------
    # Paths where to load/save data
    # -------------------------------

    # path where activities are saved
    activity_path = project_path + '/Script/result/activity/cornetz_sgd'
    os.makedirs(activity_path, exist_ok=True)
    # path where accuracies are saved
    accuracy_path = project_path + '/Script/result/accuracy/cornetz_sgd'
    os.makedirs(accuracy_path, exist_ok=True)
    # path where test logs are saved
    log_path = project_path + '/Script/result/log/test'
    os.makedirs(log_path, exist_ok=True)
    # path containing the dataset
    test_path = project_path + '/Script/images/Testolin_DeWind/test'
    # path containing the model checkpoints
    model_path = project_path + '/Script/result/model/cornetz_sgd'

    # -------------------------------
    # Test dataset
    # -------------------------------

    # No random augmentation at test time: every epoch must see exactly the
    # same inputs, so that accuracies and recorded activities are comparable.
    data_transform = transforms.Compose([
        # Resize the images to 64x64
        transforms.Resize(size=(64, 64)),
        # Convert to a float tensor; also rescales pixel values from [0, 255] to [0.0, 1.0]
        transforms.ToTensor(),
    ])

    test_data = datasets.ImageFolder(root=test_path, transform=data_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    # -------------------------------
    # Recording hooks
    # -------------------------------

    # Filled during trainer.test and emptied before every epoch.
    activity = {}  # hooked module -> list of its outputs (moved to CPU)
    batch_fields = {'label': [], 'condition': [], 'param': []}

    def hook_output(m, i, o):
        activity[m].append(o.cpu())

    class LabelConditionCallback(Callback):
        """Record the non-image entries of every test batch."""

        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
            # ImageFolder batches are [images, labels]; datasets that also yield
            # condition and param tensors (positions 2 and 3) get those recorded too.
            for values, tensor in zip(batch_fields.values(), batch[1:4]):
                values.append(tensor.cpu())

    epochs = args.epochs
    accuracy = np.zeros(len(epochs))
    trainer = pl.Trainer(default_root_dir=log_path, callbacks=[LabelConditionCallback()], deterministic=True, devices="auto", accelerator="auto")
    model = ExtendedCORnet(out_features=n_max)

    block_names = ["V1", "V2", "V4", "IT", "decoder"]
    blocks = {name: getattr(model.model.module, name) for name in block_names}
    # Hooks are attached to each block's `.output` submodule.
    modules = [blocks[name].output for name in block_names]
    module_names = {blocks[name].output: name for name in block_names}
    # Blocks exposing `times` (presumably the number of recurrent steps) are
    # assumed to emit that many outputs per batch; all others emit one.
    times = {blocks[name].output: getattr(blocks[name], 'times', 1) for name in block_names}
    for m in modules:
        m.register_forward_hook(hook_output)

    # -------------------------------
    # Testing model
    # -------------------------------

    for i, epoch in enumerate(epochs):
        epoch_dir = f'{activity_path}/epoch{epoch:02}'
        os.makedirs(epoch_dir, exist_ok=True)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the weights onto the model's own device.
        checkpoint = torch.load(f'{model_path}/epoch{epoch:02}.ckpt', map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        for values in batch_fields.values():
            values.clear()
        activity.clear()
        for m in modules:
            activity[m] = []
        metrics, = trainer.test(model, test_loader)

        for key, values in batch_fields.items():
            if not values:
                continue  # the dataset's batches do not provide this field
            array = torch.cat(values).numpy()[:, None]
            _save_npz_once(f'{epoch_dir}/{key}.npz', key, epoch, **{key: array})
        print(f'Test finished ({_rss()})')
        accuracy[i] = metrics['test_acc_epoch']

        for m in modules:
            name = module_names[m]
            npz_path = f'{epoch_dir}/{name}.npz'
            if os.path.exists(npz_path):
                print(f'{name} at epoch {epoch} already saved in .npz ({_rss()})')
                continue
            print(f'starting saving {name} at epoch {epoch} ({_rss()})')
            # Outputs are interleaved with period times[m]: entries step,
            # step + times[m], ... belong to the same step. Each step is
            # concatenated over batches (assumes batch-first outputs), and the
            # steps are stacked along axis 1.
            steps = times[m]
            tmp = torch.stack([torch.cat(activity[m][step::steps]) for step in range(steps)], dim=1).numpy()
            print(f'tmp created ({_rss()})')
            del activity[m]
            print(f'activity[m] removed ({_rss()})')
            print(name, tmp.shape)
            np.savez_compressed(npz_path, **{name: tmp})
            print(f'{name} at epoch {epoch} saved in .npz ({_rss()})')
            del tmp

    # Scalar fields: the ('epoch', np.int64, 1) spelling is deprecated in NumPy
    # ("will be understood as (type, (1,))").
    epoch_accuracy = np.zeros(len(epochs), dtype=[('epoch', np.int64), ('accuracy', np.float64)])
    epoch_accuracy['epoch'] = epochs
    epoch_accuracy['accuracy'] = accuracy
    np.save(f'{accuracy_path}/epochs_{epochs}_accuracy.npy', epoch_accuracy)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test models at different epochs')
    # required: main() needs the list (it calls len(args.epochs)); without it
    # argparse would pass None and main would crash with an obscure TypeError.
    parser.add_argument('--epochs', metavar='E', type=int, nargs="+", required=True, help='list of epochs to test')
    args = parser.parse_args()
    main(args)

In [None]:
import os
import zipfile
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl
from functools import reduce
from cornet.extend_cornet_z import ExtendedCORnet

seed = 0
pl.seed_everything(seed)

# -------------------------------
# Parameters
# -------------------------------

# ESOS: 5
# Testolin: 10
n_max = 10  # max number of dots; must match the out_features of the loaded checkpoint

# -------------------------------
# Paths where to load/save data
# -------------------------------

# path containing the dataset
#test_path = project_path + '/Script/images/ESOS/test'
#test_path = project_path + '/Script/images/Testolin_DeWind/test'
#test_path = project_path + '/Script/images/Testolin_Natural/test'
test_path = project_path + '/Script/images/Testolin_ISA2/test'
# path containing the model
model_path = project_path + '/Script/result/model/cornetz_sgd'

# -------------------------------
# Test dataset
# -------------------------------

# Same resize as in training, but no random flip: evaluation should be deterministic.
data_transform = transforms.Compose([
    # Resize the images to 300x300
    transforms.Resize(size=(300, 300)),
    # Convert to a float tensor; also rescales pixel values from [0, 255] to [0.0, 1.0]
    transforms.ToTensor(),
])

#test_data = ZipDataset(test_path, transform=data_transform)
test_data = datasets.ImageFolder(root=test_path, transform=data_transform)
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# -------------------------------
# Initializing model
# -------------------------------

model = ExtendedCORnet(out_features=n_max)

# map_location='cpu' lets a GPU-saved checkpoint load on any machine;
# load_state_dict copies the weights into the model.
checkpoint = torch.load(f'{model_path}/cornetz_testolin_isa2.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])

trainer = pl.Trainer(deterministic=True, devices="auto", accelerator="auto")
metrics, = trainer.test(model, test_loader)