In [1]:
import os
# SECURITY: the W&B API key must never be committed with the notebook.
# Supply it from outside the notebook (e.g. `export WANDB_API_KEY=...` before
# launching Jupyter, or `wandb login`). The key that used to be hard-coded here
# should be treated as leaked and revoked.
# With WANDB_MODE set to "disabled" below, wandb calls become no-ops and no key
# is needed to run this notebook.
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb

from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, compute_SCL_loss
import datetime

In [2]:
# Command-line style configuration. `parse_known_args` is used at the end so
# that extra arguments injected by the Jupyter kernel are ignored instead of
# aborting the run.
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

# --- Method: learning rule, network family, non-linearity, conv width ---
method_parser = parser.add_argument_group("Method")
method_parser.add_argument(
    '--method',
    default='BP',
    type=str,
    choices=['SIGMA', 'BP', 'FA'],
)
method_parser.add_argument(
    '--architecture',
    default='lenet',
    type=str,
    choices=['lenet', 'vgg'],
)
method_parser.add_argument(
    '--actfunc',
    default='elu',
    type=str,
    choices=['tanh', 'elu', 'relu'],
)
method_parser.add_argument(
    '--conv_dim',
    default=32,
    type=int,
    choices=[32, 64],
)

# --- Dataset: which data, batch size, validation split ratio ---
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument(
    '--dataset',
    default='CIFAR10',
    type=str,
    choices=['MNIST', 'CIFAR10'],
)
dataset_parser.add_argument('--batchsize', default=128, type=int)
dataset_parser.add_argument('--splitratio', default=0.1, type=float)

# --- Training: epochs, learning rate, optimizer ---
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', default=100, type=int)
training_parser.add_argument('--lr', default=0.05, type=float)
training_parser.add_argument(
    '--optimizer',
    default='SGD',
    type=str,
    choices=['RMSprop', 'Adam', 'SGD'],
)

# --- Seed ---
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', default=42, type=int)

# Unrecognised arguments are returned as the second element and discarded.
args, _ = parser.parse_known_args()

# Launch timestamp; it is appended to the checkpoint file name further down.
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Human-readable run identifier assembled from the main hyperparameters.
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"

# Register the run with Weights & Biases, recording hyperparameters as config.
wandb.init(
    project="opt-sigma",
    name=run_name,
    config=dict(
        algorithm=args.method,
        architecture="SimpleCNN",
        dataset=args.dataset,
        epochs=args.epochs,
        lr=args.lr,
        optimizer=args.optimizer,
        seed=args.seed,
        conv_dim=args.conv_dim,
        actfunc=args.actfunc,
    ),
)

print(f"Run name: {run_name}")

# Seed both PyTorch and NumPy random number generators.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Prefer CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# get_dataset is a project helper; it is unpacked into train/val/test loaders.
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    """Two-conv / two-linear network assembled from the project's Sigma layers.

    The stack is conv1 -> conv2 -> view1 -> fc1 -> fc2. Each layer exposes a
    forward call and a ``reverse`` call; the parameters returned by every
    layer's ``get_parameters()`` are gathered into ``forward_params`` and
    ``backward_params`` so they can be handed to separate optimizers.

    Args:
        args: namespace with at least ``dataset`` (and ``architecture`` when
            ``dataset == "MNIST"``); it is also passed through to every
            SigmaConv / SigmaLinear constructor.

    Raises:
        ValueError: if the dataset/architecture combination has no layer
            definition here.
    """

    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        # NOTE(review): the trailing positional ints passed to SigmaConv are
        # presumably kernel size / stride / padding style settings -- confirm
        # against sigma_layer.SigmaConv.
        if args.dataset == "CIFAR10":
            self.conv1 = SigmaConv(args, 3, 32, 5, 3, 2)
            self.conv2 = SigmaConv(args, 32, 64, 3, 3, 2)
            self.view1 = SigmaView((64, 8, 8), 64 * 8 * 8)
            self.fc1 = SigmaLinear(args, 64 * 8 * 8, 512)
            self.fc2 = SigmaLinear(args, 512, 10)

        elif args.dataset == "MNIST" and args.architecture == "lenet":
            self.conv1 = SigmaConv(args, 1, 32, 5, 2, 3)
            self.conv2 = SigmaConv(args, 32, 64, 5, 2, 3)
            self.view1 = SigmaView((64, 4, 4), 64 * 4 * 4)
            self.fc1 = SigmaLinear(args, 64 * 4 * 4, 512)
            self.fc2 = SigmaLinear(args, 512, 10)

        else:
            # Previously this fell through with no layers defined and failed
            # later with an AttributeError on self.conv1.
            raise ValueError(
                f"Unsupported configuration: dataset={args.dataset!r}, "
                f"architecture={getattr(args, 'architecture', None)!r}"
            )

        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):
        """Run the forward stack and return every layer's activation.

        Args:
            x: input batch.
            detach_grad: forwarded unchanged to each layer.
            return_activations: accepted for interface compatibility; unused.

        Returns:
            ``[a1, a2, a3, a4]`` where ``a2`` is the conv2 output after
            ``view1`` and ``a4`` is the fc2 output.
        """
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]

    def reverse(self, target, detach_grad=True, return_activations=True):
        """Propagate a target backwards through the layers' ``reverse`` paths.

        Args:
            target: either a 1-D int64 tensor of class indices (converted to
                10-way one-hot floats here) or an already-encoded tensor,
                which is passed through unchanged.
            detach_grad: forwarded unchanged to each layer's ``reverse``.
            return_activations: accepted for interface compatibility; unused.

        Returns:
            ``[b1, b2, b3, target]``, ordered to line up with the list
            returned by ``forward``. ``b2`` is the fc1 reverse output after
            ``view1.reverse``; conv1's reverse is not invoked.
        """
        # Any 1-D vector of integer labels is one-hot encoded, not only one of
        # length exactly 10 as before.
        if target.dim() == 1 and target.dtype == torch.long:
            target = F.one_hot(target, num_classes=10).float()
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    """Loss wrapper offering a final-layer-only mode and a layer-wise mode.

    ``final_criteria`` is cross-entropy on the last activation. In "local"
    mode every earlier activation is additionally paired with the signal at
    the same position and scored with ``compute_SCL_loss`` (project helper).

    Args:
        args: run configuration; stored on the instance, and ``args.method``
            is copied to ``self.method``.
    """

    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method

    def forward(self, activations, signals, target, method="final"):
        """Compute the training loss.

        Args:
            activations: list of per-layer activations; the last entry is
                used as the classification logits.
            signals: list of per-layer signals; only read when
                ``method == "local"`` (the last entry is ignored).
            target: class-index targets for the cross-entropy term.
            method: ``"final"`` for cross-entropy on the last activation only,
                ``"local"`` for the sum of per-layer local losses plus that
                cross-entropy term.

        Returns:
            ``(loss, ce_value)``: the tensor to back-propagate and the
            cross-entropy term as a Python float.

        Raises:
            ValueError: if ``method`` is neither ``"local"`` nor ``"final"``.
        """
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            # Only the final (cross-entropy) term is reported as the scalar.
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target)
            return loss, loss.item()
        # Previously an unknown mode silently returned None, which only failed
        # later when the caller tried to unpack the result.
        raise ValueError(f"Unknown loss method {method!r}; expected 'local' or 'final'")
        
# Build the network and move it to the selected device.
model = SigmaModel_SimpleCNN(args)
model.to(device)

# Optimizer over the forward parameters only. Every value accepted by
# --optimizer is handled (previously only SGD defined `forward_optimizer`, so
# Adam/RMSprop crashed with a NameError), and the learning rate comes from
# --lr instead of a hard-coded 0.05 (the --lr default is 0.05, so the default
# run is unchanged).
if args.optimizer == "SGD":
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr, momentum=0.9)
elif args.optimizer == "Adam":
    forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop":
    forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr)
else:
    raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")

# NOTE(review): T_max is fixed at 85 while --epochs defaults to 100, so the
# cosine schedule turns upward again after epoch 85 -- confirm this is intended.
forward_scheduler = CosineAnnealingLR(forward_optimizer, T_max=85, eta_min=1e-05)
criteria = SigmaLoss(args)

# Pre-compute the reverse-path signals for the ten class labels 0..9 once,
# without tracking gradients.
with torch.no_grad():
    signals = model.reverse(torch.arange(10, device=device), return_activations=True)
    
best_val_loss = float('inf')

# Make sure the checkpoint directory exists before the first torch.save call.
os.makedirs('./saved_models', exist_ok=True)

print("No LN1")
for epoch in range(args.epochs):
    # ---- Training ----
    # Explicitly enter training mode; validation below switches to eval mode.
    model.train()
    train_loss, train_counter = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        forward_optimizer.zero_grad()
        activations = model(data, detach_grad=False)
        loss, loss_item = criteria(activations, signals, target, method="final")
        loss.backward()
        forward_optimizer.step()
        # Weight the batch loss by batch size so the epoch value is a per-sample mean.
        train_loss += loss_item * len(data)
        train_counter += len(data)
    forward_scheduler.step()
    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)

    # ---- Validation ----
    # Same mode as the final test evaluation, so model selection is consistent.
    model.eval()
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            activations = model(data, detach_grad=True)
            _, loss_item = criteria(activations, signals, target, method="final")
            _, predicted = torch.max(activations[-1].detach(), 1)
            val_correct += (predicted == target).sum().item()
            val_loss += loss_item * len(data)
            val_counter += len(data)

    wandb.log({'val_loss': val_loss / val_counter, 'val_acc': val_correct / val_counter}, step=epoch)

    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    # Keep the checkpoint with the lowest summed validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model
# map_location keeps the load working when the checkpoint was written on a
# different device than the one in use now.
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt', map_location=device))
model.eval()
correct = 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # The model returns a list of per-layer activations; the last entry
        # holds the class logits (the old `outputs.data` failed on the list).
        outputs = model(data, detach_grad=False)
        _, predicted = torch.max(outputs[-1], 1)
        correct += (predicted == target).sum().item()
        # Score the current test batch; previously this reused the stale
        # `activations` left over from the last validation batch.
        loss, loss_item = criteria(outputs, signals, target, method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
           'test_acc': 100 * correct / test_counter})

print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Run name: CIFAR10_BP_conv32-actelu_SGD_0.05_42
Files already downloaded and verified
Files already downloaded and verified
No LN1
Epoch 0 | train loss 1.6142 | val loss 1.4719 | val acc 47.8100
Epoch 1 | train loss 1.4809 | val loss 1.6353 | val acc 47.6900
Epoch 2 | train loss 1.6606 | val loss 1.7553 | val acc 48.3500
Epoch 3 | train loss 1.7787 | val loss 2.0160 | val acc 40.5300
Epoch 4 | train loss 1.8348 | val loss 1.8787 | val acc 47.1300
Epoch 5 | train loss 1.8683 | val loss 2.1641 | val acc 36.5500
Epoch 6 | train loss 2.2705 | val loss 2.3026 | val acc 9.0300
Epoch 7 | train loss 2.3026 | val loss 2.3026 | val acc 8.8100
Epoch 8 | train loss 2.3026 | val loss 2.3026 | val acc 8.9200
Epoch 9 | train loss 2.3026 | val loss 2.3026 | val acc 8.9000
Epoch 10 | train loss 2.3026 | val loss 2.3026 | val acc 8.7700
Epoch 11 | train loss 2.3026 | val loss 2.3026 | val acc 8.8100
Epoch 12 | train loss 2.3026 | val loss 2.3026 | val acc 8.7600
Epoch 13 | train loss 2.3026 | val loss 2.

KeyboardInterrupt: 

In [None]:
# With Cropping (edge)
# transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(3 * 0.2023, 3 * 0.1994, 3 * 0.2010)) # Following pl_bolt's
                                        
# Epoch 0 | train loss 1.4767 | val loss 1.0709 | val acc 62.5100
# Epoch 1 | train loss 1.0998 | val loss 0.9122 | val acc 68.3900
# Epoch 2 | train loss 0.9462 | val loss 0.8498 | val acc 70.5800
# Epoch 3 | train loss 0.8750 | val loss 0.8037 | val acc 72.3100
# Epoch 4 | train loss 0.8145 | val loss 0.7622 | val acc 73.4700
# Epoch 5 | train loss 0.7637 | val loss 0.7678 | val acc 72.9200
# Epoch 6 | train loss 0.7437 | val loss 0.7297 | val acc 74.7100
# Epoch 7 | train loss 0.7246 | val loss 0.7414 | val acc 74.4800
# Epoch 8 | train loss 0.6919 | val loss 0.7192 | val acc 75.2600
# Epoch 9 | train loss 0.6813 | val loss 0.7395 | val acc 74.7300
# Epoch 10 | train loss 0.6572 | val loss 0.7179 | val acc 75.7500
# Epoch 11 | train loss 0.6289 | val loss 0.7234 | val acc 75.7500
# Epoch 12 | train loss 0.6311 | val loss 0.7124 | val acc 76.4500
# Epoch 13 | train loss 0.6076 | val loss 0.7303 | val acc 75.5700
# Epoch 14 | train loss 0.6110 | val loss 0.7216 | val acc 75.9600
# Epoch 15 | train loss 0.5847 | val loss 0.7151 | val acc 75.9800
# Epoch 16 | train loss 0.5721 | val loss 0.7057 | val acc 76.8200
# Epoch 17 | train loss 0.5608 | val loss 0.7082 | val acc 76.3500
# Epoch 18 | train loss 0.5597 | val loss 0.7072 | val acc 77.0000
# Epoch 19 | train loss 0.5428 | val loss 0.6865 | val acc 77.6100
# Epoch 20 | train loss 0.5215 | val loss 0.7228 | val acc 76.8000

In [None]:
# Epoch 0 | train loss 1.6212 | val loss 1.2239 | val acc 57.7800
# Epoch 1 | train loss 1.1105 | val loss 0.9573 | val acc 67.1100
# Epoch 2 | train loss 0.9249 | val loss 0.8443 | val acc 70.5500
# Epoch 3 | train loss 0.8030 | val loss 0.7778 | val acc 72.5400
# Epoch 4 | train loss 0.7149 | val loss 0.7120 | val acc 75.0400
# Epoch 5 | train loss 0.6382 | val loss 0.6874 | val acc 75.8700
# Epoch 6 | train loss 0.5760 | val loss 0.6636 | val acc 76.4500
# Epoch 7 | train loss 0.5212 | val loss 0.6519 | val acc 77.5800
# Epoch 8 | train loss 0.4668 | val loss 0.6548 | val acc 77.4300
# Epoch 9 | train loss 0.4243 | val loss 0.6586 | val acc 77.5500
# Epoch 10 | train loss 0.3872 | val loss 0.6611 | val acc 78.0400
# Epoch 11 | train loss 0.3390 | val loss 0.6800 | val acc 77.9200
# Epoch 12 | train loss 0.3089 | val loss 0.7055 | val acc 77.8300
# Epoch 13 | train loss 0.2707 | val loss 0.7307 | val acc 77.7800
# Epoch 14 | train loss 0.2408 | val loss 0.7463 | val acc 77.9100
# Epoch 15 | train loss 0.2150 | val loss 0.7790 | val acc 78.4800

313