In [1]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [2]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='BP', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--architecture', type=str, default='lenet', choices=['lenet', 'vgg'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=0.1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        if args.dataset == "CIFAR10":
            self.LN1 = torch.nn.LayerNorm((3,32,32), elementwise_affine=False) 
            self.LN2 = torch.nn.LayerNorm((args.conv_dim,16,16), elementwise_affine=False) 
            self.conv1 = SigmaConv(3, args.conv_dim, 3, args)
            self.conv2 = SigmaConv(args.conv_dim, 64, 3, args)
            self.view1 = SigmaView((64, 8, 8), 64 * 8 * 8)
            self.fc1 = SigmaLinear(64 * 8 * 8, 128, args)
            self.fc2 = SigmaLinear(128, 10, args)
            
        elif args.dataset == "MNIST":
            if args.archit
            self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
            self.conv1 = SigmaConv(1, 32, 5, args)
            self.conv2 = SigmaConv(args.conv_dim, 64, 3, args)
            self.view1 = SigmaView((64, 7, 7), 64 * 7 * 7)
            self.fc1 = SigmaLinear(64 * 7 * 7, 128, args)
            self.fc2 = SigmaLinear(128, 10, args)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        # x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target) * 1.1 - 0.1
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
    backward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')

print("No LN1")
for epoch in tqdm(range(args.epochs)):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model), forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
           'test_acc': 100 * correct / test_counter})

print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Run name: CIFAR10_BP_conv32-actelu_SGD_0.1_2023
Files already downloaded and verified
Files already downloaded and verified
No LN1


  1%|▍                                          | 1/100 [00:12<21:00, 12.73s/it]

Epoch 0 | train loss 1.9396 | val loss 1.6513 | val acc 43.7500


  2%|▊                                          | 2/100 [00:25<20:41, 12.67s/it]

Epoch 1 | train loss 1.6708 | val loss 1.4639 | val acc 49.6400


  3%|█▎                                         | 3/100 [00:37<20:02, 12.40s/it]

Epoch 2 | train loss 1.4924 | val loss 1.3220 | val acc 54.5500


  4%|█▋                                         | 4/100 [00:50<20:18, 12.69s/it]

Epoch 3 | train loss 1.3795 | val loss 1.2174 | val acc 58.1300


  5%|██▏                                        | 5/100 [01:03<20:24, 12.89s/it]

Epoch 4 | train loss 1.2958 | val loss 1.1654 | val acc 59.7500


  6%|██▌                                        | 6/100 [01:17<20:28, 13.07s/it]

Epoch 5 | train loss 1.2253 | val loss 1.1062 | val acc 61.9300


  7%|███                                        | 7/100 [01:30<20:22, 13.15s/it]

Epoch 6 | train loss 1.1354 | val loss 1.0653 | val acc 63.7200


  8%|███▍                                       | 8/100 [01:44<20:24, 13.31s/it]

Epoch 7 | train loss 1.0765 | val loss 1.0349 | val acc 63.9500


  9%|███▊                                       | 9/100 [01:59<21:12, 13.99s/it]

Epoch 8 | train loss 1.0156 | val loss 1.0118 | val acc 65.2200


 10%|████▏                                     | 10/100 [02:12<20:23, 13.60s/it]

Epoch 9 | train loss 0.9628 | val loss 0.9937 | val acc 66.6100


 11%|████▌                                     | 11/100 [02:25<19:51, 13.38s/it]

Epoch 10 | train loss 0.9528 | val loss 0.9734 | val acc 67.1600


 12%|█████                                     | 12/100 [02:37<19:13, 13.11s/it]

Epoch 11 | train loss 0.9067 | val loss 0.9665 | val acc 67.3100


 13%|█████▍                                    | 13/100 [02:50<18:59, 13.10s/it]

Epoch 12 | train loss 0.8450 | val loss 0.9726 | val acc 67.7900


 14%|█████▉                                    | 14/100 [03:03<18:39, 13.02s/it]

Epoch 13 | train loss 0.8469 | val loss 0.9517 | val acc 68.1300


 15%|██████▎                                   | 15/100 [03:16<18:27, 13.03s/it]

Epoch 14 | train loss 0.7895 | val loss 0.9602 | val acc 68.0400


 16%|██████▋                                   | 16/100 [03:29<18:06, 12.94s/it]

Epoch 15 | train loss 0.7631 | val loss 0.9541 | val acc 68.4100


 17%|███████▏                                  | 17/100 [03:42<17:50, 12.90s/it]

Epoch 16 | train loss 0.7338 | val loss 0.9638 | val acc 68.1600


 18%|███████▌                                  | 18/100 [03:54<17:12, 12.59s/it]

Epoch 17 | train loss 0.7365 | val loss 0.9656 | val acc 68.7200


 19%|███████▉                                  | 19/100 [04:04<16:16, 12.06s/it]

Epoch 18 | train loss 0.6860 | val loss 1.0191 | val acc 66.9800


 20%|████████▍                                 | 20/100 [04:16<15:46, 11.83s/it]

Epoch 19 | train loss 0.6577 | val loss 0.9972 | val acc 68.9600


 21%|████████▊                                 | 21/100 [04:27<15:24, 11.70s/it]

Epoch 20 | train loss 0.6422 | val loss 0.9870 | val acc 68.6900


 22%|█████████▏                                | 22/100 [04:39<15:25, 11.86s/it]

Epoch 21 | train loss 0.6112 | val loss 1.0002 | val acc 69.1400


 23%|█████████▋                                | 23/100 [04:50<14:50, 11.57s/it]

Epoch 22 | train loss 0.6237 | val loss 1.0037 | val acc 69.1800


 24%|██████████                                | 24/100 [05:01<14:19, 11.31s/it]

Epoch 23 | train loss 0.5648 | val loss 1.0337 | val acc 67.9400


 25%|██████████▌                               | 25/100 [05:12<14:09, 11.33s/it]

Epoch 24 | train loss 0.5250 | val loss 1.0944 | val acc 67.9500


 26%|██████████▉                               | 26/100 [05:25<14:20, 11.63s/it]

Epoch 25 | train loss 0.5348 | val loss 1.0584 | val acc 68.7100


 27%|███████████▎                              | 27/100 [05:36<13:53, 11.42s/it]

Epoch 26 | train loss 0.5009 | val loss 1.1337 | val acc 68.2100


 28%|███████████▊                              | 28/100 [05:46<13:27, 11.21s/it]

Epoch 27 | train loss 0.5112 | val loss 1.0826 | val acc 68.6200


 29%|████████████▏                             | 29/100 [05:57<13:07, 11.09s/it]

Epoch 28 | train loss 0.4662 | val loss 1.1563 | val acc 68.3400


 30%|████████████▌                             | 30/100 [06:08<12:57, 11.11s/it]

Epoch 29 | train loss 0.4643 | val loss 1.0898 | val acc 68.8800


 31%|█████████████                             | 31/100 [06:20<13:00, 11.31s/it]

Epoch 30 | train loss 0.4526 | val loss 1.1297 | val acc 68.3700


 32%|█████████████▍                            | 32/100 [06:31<12:48, 11.30s/it]

Epoch 31 | train loss 0.4051 | val loss 1.1499 | val acc 68.0900


 33%|█████████████▊                            | 33/100 [06:43<12:36, 11.29s/it]

Epoch 32 | train loss 0.4143 | val loss 1.1948 | val acc 68.8700


 34%|██████████████▎                           | 34/100 [06:53<12:15, 11.15s/it]

Epoch 33 | train loss 0.3874 | val loss 1.2175 | val acc 67.8000


 35%|██████████████▋                           | 35/100 [07:05<12:17, 11.35s/it]

Epoch 34 | train loss 0.4226 | val loss 1.1485 | val acc 68.1100


 36%|███████████████                           | 36/100 [07:17<12:12, 11.44s/it]

Epoch 35 | train loss 0.3450 | val loss 1.2199 | val acc 67.9400


 37%|███████████████▌                          | 37/100 [07:29<12:11, 11.61s/it]

Epoch 36 | train loss 0.3365 | val loss 1.2696 | val acc 68.7100


 38%|███████████████▉                          | 38/100 [07:41<11:59, 11.61s/it]

Epoch 37 | train loss 0.3262 | val loss 1.2877 | val acc 68.6100


 39%|████████████████▍                         | 39/100 [07:53<11:57, 11.77s/it]

Epoch 38 | train loss 0.3279 | val loss 1.3221 | val acc 67.5000


 40%|████████████████▊                         | 40/100 [08:05<11:50, 11.84s/it]

Epoch 39 | train loss 0.3243 | val loss 1.3859 | val acc 67.4000


 41%|█████████████████▏                        | 41/100 [08:18<11:56, 12.15s/it]

Epoch 40 | train loss 0.3322 | val loss 1.3122 | val acc 68.8700


 42%|█████████████████▋                        | 42/100 [08:31<12:01, 12.44s/it]

Epoch 41 | train loss 0.2915 | val loss 1.3888 | val acc 67.7300


 43%|██████████████████                        | 43/100 [08:44<12:08, 12.78s/it]

Epoch 42 | train loss 0.3223 | val loss 1.3765 | val acc 67.5700


 44%|██████████████████▍                       | 44/100 [08:58<12:04, 12.93s/it]

Epoch 43 | train loss 0.3043 | val loss 1.4347 | val acc 67.9500


 45%|██████████████████▉                       | 45/100 [09:11<11:59, 13.08s/it]

Epoch 44 | train loss 0.2701 | val loss 1.3834 | val acc 67.8600


 46%|███████████████████▎                      | 46/100 [09:24<11:49, 13.13s/it]

Epoch 45 | train loss 0.2766 | val loss 1.4277 | val acc 67.6700


 47%|███████████████████▋                      | 47/100 [09:37<11:27, 12.98s/it]

Epoch 46 | train loss 0.2950 | val loss 1.4576 | val acc 67.8200


 48%|████████████████████▏                     | 48/100 [09:49<11:09, 12.88s/it]

Epoch 47 | train loss 0.2585 | val loss 1.5213 | val acc 67.8600


 49%|████████████████████▌                     | 49/100 [10:03<11:01, 12.97s/it]

Epoch 48 | train loss 0.2624 | val loss 1.5946 | val acc 67.8300


 50%|█████████████████████                     | 50/100 [10:16<10:54, 13.10s/it]

Epoch 49 | train loss 0.2543 | val loss 1.5304 | val acc 68.0300


 51%|█████████████████████▍                    | 51/100 [10:29<10:46, 13.20s/it]

Epoch 50 | train loss 0.2607 | val loss 1.5437 | val acc 67.7500


 52%|█████████████████████▊                    | 52/100 [10:42<10:26, 13.05s/it]

Epoch 51 | train loss 0.2576 | val loss 1.6690 | val acc 68.3300


 53%|██████████████████████▎                   | 53/100 [10:55<10:13, 13.05s/it]

Epoch 52 | train loss 0.2187 | val loss 1.6816 | val acc 68.4800


 54%|██████████████████████▋                   | 54/100 [11:09<10:06, 13.19s/it]

Epoch 53 | train loss 0.2654 | val loss 1.6061 | val acc 67.6800


 55%|███████████████████████                   | 55/100 [11:21<09:37, 12.83s/it]

Epoch 54 | train loss 0.2071 | val loss 1.5964 | val acc 67.4300


 56%|███████████████████████▌                  | 56/100 [11:33<09:23, 12.81s/it]

Epoch 55 | train loss 0.2389 | val loss 1.6044 | val acc 67.3700


 57%|███████████████████████▉                  | 57/100 [11:46<09:10, 12.81s/it]

Epoch 56 | train loss 0.2306 | val loss 1.7868 | val acc 68.0000


 58%|████████████████████████▎                 | 58/100 [11:59<08:54, 12.73s/it]

Epoch 57 | train loss 0.2703 | val loss 1.6642 | val acc 67.6800


 59%|████████████████████████▊                 | 59/100 [12:11<08:31, 12.49s/it]

Epoch 58 | train loss 0.2146 | val loss 1.6428 | val acc 68.0700


 60%|█████████████████████████▏                | 60/100 [12:23<08:20, 12.52s/it]

Epoch 59 | train loss 0.1755 | val loss 1.7637 | val acc 67.3400


 61%|█████████████████████████▌                | 61/100 [12:36<08:13, 12.65s/it]

Epoch 60 | train loss 0.2153 | val loss 1.7879 | val acc 67.5600


 61%|█████████████████████████▌                | 61/100 [12:43<08:08, 12.52s/it]


KeyboardInterrupt: 