In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
import pandas as pd
import bz2
import pickle
from tqdm import tqdm

# Test to load the SoftHebb package
from SoftHebb.dataset import make_data_loaders
from SoftHebb.model import load_layers, save_layers, HebbianOptimizer, AggregateOptim
from SoftHebb.engine import train_sup, evaluate_sup
from SoftHebb.train import check_dimension, training_config
from SoftHebb.utils import seed_init_fn, load_presets, load_config_dataset, CustomStepLR
from SoftHebb.log import Log

In [ ]:
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run settings for this notebook. Only "preset", "dataset_sup",
# "dataset_unsup", "seed", "training_mode" and "training_blocks" are read
# in this cell.
params = {
    "preset": "2SoftMlpMNIST",
    "dataset_sup": "MNIST",
    "dataset_unsup": "MNIST",
    "seed": 52,
    "model-name": "2SoftMlpMNIST",
    "training_mode": "simultaneous",
    "training_blocks": None,
    "resume": False,
    "save": False,
}

# The preset name doubles as the model name used by later cells.
name_model = params["preset"]
blocks = load_presets(name_model)

# NOTE(review): the second argument (0.8) is presumably a train/validation
# split fraction -- confirm against load_config_dataset.
dataset_sup_config = load_config_dataset(params["dataset_sup"], 0.8)
dataset_unsup_config = load_config_dataset(params["dataset_unsup"], 0.8)
dataset_sup_config['validation'] = True

# An explicit seed in params overrides whatever the dataset configs carry.
if params["seed"] is not None:
    dataset_sup_config['seed'] = dataset_unsup_config['seed'] = params["seed"]

# Seed before anything below is built from the configs.
if dataset_sup_config['seed'] is not None:
    seed_init_fn(dataset_sup_config['seed'])

blocks = check_dimension(blocks, dataset_sup_config)

train_config = training_config(
    blocks,
    dataset_sup_config,
    dataset_unsup_config,
    params["training_mode"],
    params["training_blocks"],
)

# 't1' is the training-config entry used throughout this notebook.
config = train_config['t1']

# Supervised loaders shared by the ResNet18 and SoftHebb cells below.
train_loader, val_loader, test_loader = make_data_loaders(
    dataset_sup_config, config['batch_size'], device
)


In [ ]:
# ResNet18 with randomly initialised weights (no pretrained checkpoint).
model = resnet18(weights=None)

# Replace the stem convolution so the network takes 1-channel (grayscale)
# input.
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

# Replace the classification head with a 10-way output (for MNIST); the
# input width is taken from the head being replaced.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10)

# Loss and optimizer used by the training cell below.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [ ]:
# Train `model` for a fixed number of epochs, validating after each one,
# checkpointing whenever validation accuracy improves, and finally writing
# the per-epoch metrics to CSV.
num_epochs = 10
model.to(device)

# Per-epoch history, exported to CSV after the loop.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
best_val_accuracy = 0.0

for epoch in range(num_epochs):
    # ---- Training pass --------------------------------------------------
    model.train()
    running_loss, correct_train, total_train = 0.0, 0.0, 0.0
    for images, labels in tqdm(train_loader, desc="Training:"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.detach().max(dim=1).indices
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    # Reported loss is the mean of the per-batch losses.
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct_train / total_train
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # ---- Validation pass (no gradient tracking) --------------------------
    model.eval()
    running_loss, correct_val, total_val = 0.0, 0.0, 0.0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation:"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            predicted = outputs.detach().max(dim=1).indices
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    val_loss = running_loss / len(val_loader)
    val_accuracy = 100 * correct_val / total_val
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

    # Keep only the best-so-far weights: overwrite the checkpoint file each
    # time validation accuracy strictly improves.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        checkpoint = dict(
            epoch=epoch,  # zero-based epoch index
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            val_accuracy=val_accuracy,
        )
        with bz2.BZ2File('best_model_checkpoint.pbz2', 'w') as f:
            pickle.dump(checkpoint, f)

# Export the per-epoch history (epochs numbered from 1) to CSV.
metrics_df = pd.DataFrame(
    {
        'Epoch': list(range(1, num_epochs + 1)),
        'Train Loss': train_losses,
        'Train Accuracy': train_accuracies,
        'Val Loss': val_losses,
        'Val Accuracy': val_accuracies,
    }
)
metrics_df.to_csv('training_metrics.csv', index=False)

In [ ]:
# Restore the best checkpoint written by the training cell.
# NOTE: pickle.load can execute arbitrary code; only use it on checkpoint
# files produced by this notebook.
with bz2.BZ2File('best_model_checkpoint.pbz2', 'rb') as f:
    checkpoint = pickle.load(f)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The stored epoch is zero-based, hence the +1 in the message below.
epoch, best_val_accuracy = checkpoint['epoch'], checkpoint['val_accuracy']

print(f'Loaded checkpoint from epoch {epoch+1} with validation accuracy: {best_val_accuracy:.2f}%')

# Evaluate the restored weights on the held-out test loader.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing:"):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.detach().max(dim=1).indices

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

In [ ]:
# Build the SoftHebb model from the preset blocks under the preset's name.
# NOTE(review): the third positional argument (False) is presumably the
# "resume" switch, i.e. start from fresh weights -- confirm against the
# load_layers signature.
SoftHebb_model = load_layers(blocks, name_model, False)

In [ ]:
def train_model(
        final_epoch: int,
        print_freq: int,
        lr: float,
        folder_name: str,
        model,
        device,
        log,
        blocks,
        learning_mode: str = 'BP',
        save_batch: bool = True,
        save: bool = True,
        report=None,
        plot_fc=None,
        model_dir=None,
):
    """
    Hybrid training of one model, as used in simultaneous training mode.

    An Adam optimizer and a ``HebbianOptimizer`` are wrapped in a single
    ``AggregateOptim`` and handed to ``train_sup`` once per epoch. On the
    first epoch, the last epoch and every ``print_freq``-th epoch the model
    is evaluated with ``evaluate_sup``, the result is logged, and the
    optional reporting / saving / plotting hooks run.

    NOTE: ``train_loader`` and ``val_loader`` are not parameters; they are
    read from the enclosing (notebook-global) scope.

    Args:
        final_epoch: Number of epochs to run (epochs are counted from 1).
        print_freq: Evaluate and log every ``print_freq`` epochs.
        lr: Learning rate given to the Adam optimizer.
        folder_name: Forwarded to ``save_layers`` when ``save`` is true.
        model: Model to train; must expose ``parameters()`` and, when
            ``report`` is given, ``convergence()``.
        device: Forwarded to ``train_sup`` and ``evaluate_sup``.
        log: Log object providing ``new_log_batch()``, ``step()``,
            ``verbose()`` and ``data``.
        blocks: Block identifiers forwarded to ``train_sup`` and
            ``save_layers``, and iterated over for ``plot_fc``.
        learning_mode: Forwarded to ``train_sup``.
        save_batch: Forwarded to ``log.step`` as ``save``.
        save: When true, call ``save_layers`` on every evaluation epoch.
        report: Optional callable invoked with keyword metrics
            (``train_loss``, ``train_acc``, ``test_loss``, ``test_acc``,
            ``convergence``, ``R1``); when omitted, ``log.verbose()`` is
            called instead.
        plot_fc: Optional callable invoked as ``plot_fc(model, block)`` for
            each block on every evaluation epoch.
        model_dir: Forwarded to ``save_layers`` as ``storing_path``.
    """

    print('\n', '********** Hybrid learning of blocks %s **********' % blocks)

    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=lr)
    hebbian = HebbianOptimizer(model)
    lr_schedule = CustomStepLR(adam, final_epoch)
    combined_optimizer = AggregateOptim((hebbian, adam))
    batch_log = log.new_log_batch()

    for epoch in range(1, final_epoch + 1):
        # train_sup returns (measures, lr); only the learning rate is used.
        _, current_lr = train_sup(
            model, loss_fn, combined_optimizer, train_loader, device,
            batch_log, learning_mode, blocks,
        )

        if lr_schedule is not None:
            lr_schedule.step()

        # Evaluate only on the first, the last and every print_freq-th epoch.
        is_eval_epoch = epoch % print_freq == 0 or epoch == final_epoch or epoch == 1
        if not is_eval_epoch:
            continue

        val_loss, val_acc = evaluate_sup(model, loss_fn, val_loader, device)
        batch_log = log.step(epoch, batch_log, val_loss, val_acc, current_lr, save=save_batch)

        if report is None:
            log.verbose()
        else:
            # The most recent log row carries the metrics just recorded.
            _, train_loss, train_acc, test_loss, test_acc = log.data[-1]
            conv, R1 = model.convergence()
            report(
                train_loss=train_loss,
                train_acc=train_acc,
                test_loss=test_loss,
                test_acc=test_acc,
                convergence=conv,
                R1=R1,
            )

        if save:
            save_layers(model, folder_name, epoch, blocks, storing_path=model_dir)

        if plot_fc is not None:
            for block in blocks:
                plot_fc(model, block)

In [ ]:
# Build the run log from the training configuration and pick the 't1' entry.
log = Log(train_config)
config = train_config['t1']

# Train the SoftHebb model with the 't1' hyper-parameters; layer saving is
# disabled for this run.
train_model(
    final_epoch=config['nb_epoch'],
    print_freq=config['print_freq'],
    lr=config['lr'],
    folder_name=name_model,
    model=SoftHebb_model,
    device=device,
    log=log.sup['t1'],
    blocks=config['blocks'],
    save=False,
)

In [ ]:
# Evaluate the trained SoftHebb model on the test loader using the same
# cross-entropy criterion as in training.
criterion = nn.CrossEntropyLoss()
loss_test, acc_test = evaluate_sup(SoftHebb_model, criterion, test_loader, device)

# NOTE(review): the message assumes evaluate_sup returns accuracy already
# scaled to percent -- confirm against SoftHebb.engine.evaluate_sup.
print(f'Accuracy of the model on the test images: {acc_test:.2f}%')