In [1]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaConvT, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [4]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='SIGMA', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        if args.dataset == "CIFAR10":
            self.LN1 = torch.nn.LayerNorm((3,32,32), elementwise_affine=False) 
            self.LN2 = torch.nn.LayerNorm((args.conv_dim,16,16), elementwise_affine=False) 
            self.conv1 = SigmaConvT(3, args.conv_dim, 3, args)
            self.conv2 = SigmaConvT(args.conv_dim, 64, 3, args)
            self.view1 = SigmaView((64, 8, 8), 64 * 8 * 8)
            self.fc1 = SigmaLinear(64 * 8 * 8, 128, args)
            self.fc2 = SigmaLinear(128, 10, args)
            
        elif args.dataset == "MNIST":
            self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
            self.LN2 = torch.nn.LayerNorm((args.conv_dim,14,14), elementwise_affine=False) 
            self.conv1 = SigmaConv(1, args.conv_dim, 3, args)
            self.conv2 = SigmaConv(args.conv_dim, 64, 3, args)
            self.view1 = SigmaView((64, 7, 7), 64 * 7 * 7)
            self.fc1 = SigmaLinear(64 * 7 * 7, 128, args)
            self.fc2 = SigmaLinear(128, 10, args)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        # x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target) * 1.1 - 0.1
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
    backward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')

print("No LN1")
for epoch in tqdm(range(args.epochs)):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            # signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model), forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
           'test_acc': 100 * correct / test_counter})

print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Run name: CIFAR10_SIGMA_conv32-actelu_SGD_1_2023
Files already downloaded and verified
Files already downloaded and verified
No LN1


  1%|▍                                          | 1/100 [00:12<21:10, 12.84s/it]

Epoch 0 | train loss 1.9366 | val loss 1.6378 | val acc 42.0200


  2%|▊                                          | 2/100 [00:25<21:11, 12.97s/it]

Epoch 1 | train loss 1.7337 | val loss 1.5443 | val acc 45.3200


  3%|█▎                                         | 3/100 [00:38<20:55, 12.95s/it]

Epoch 2 | train loss 1.6262 | val loss 1.4805 | val acc 47.2400


  4%|█▋                                         | 4/100 [00:50<19:58, 12.49s/it]

Epoch 3 | train loss 1.5856 | val loss 1.4413 | val acc 49.4800


  5%|██▏                                        | 5/100 [01:01<19:06, 12.06s/it]

Epoch 4 | train loss 1.5380 | val loss 1.4456 | val acc 48.8000


  6%|██▌                                        | 6/100 [01:13<18:36, 11.88s/it]

Epoch 5 | train loss 1.4936 | val loss 1.3480 | val acc 52.8200


  7%|███                                        | 7/100 [01:24<18:05, 11.67s/it]

Epoch 6 | train loss 1.4344 | val loss 1.3465 | val acc 52.8900


  8%|███▍                                       | 8/100 [01:36<17:45, 11.59s/it]

Epoch 7 | train loss 1.3892 | val loss 1.3095 | val acc 54.1000


  9%|███▊                                       | 9/100 [01:48<17:54, 11.81s/it]

Epoch 8 | train loss 1.3561 | val loss 1.2774 | val acc 55.0800


 10%|████▏                                     | 10/100 [01:59<17:19, 11.55s/it]

Epoch 9 | train loss 1.3094 | val loss 1.2583 | val acc 56.0100


 11%|████▌                                     | 11/100 [02:10<16:47, 11.32s/it]

Epoch 10 | train loss 1.2928 | val loss 1.2420 | val acc 56.3300


 12%|█████                                     | 12/100 [02:20<16:22, 11.16s/it]

Epoch 11 | train loss 1.2763 | val loss 1.2368 | val acc 56.9200


 13%|█████▍                                    | 13/100 [02:31<16:03, 11.07s/it]

Epoch 12 | train loss 1.2487 | val loss 1.2276 | val acc 56.8700


 14%|█████▉                                    | 14/100 [02:42<15:40, 10.94s/it]

Epoch 13 | train loss 1.2370 | val loss 1.2140 | val acc 57.6500


 15%|██████▎                                   | 15/100 [02:53<15:23, 10.86s/it]

Epoch 14 | train loss 1.2085 | val loss 1.2009 | val acc 58.4900


 16%|██████▋                                   | 16/100 [03:03<15:05, 10.78s/it]

Epoch 15 | train loss 1.2104 | val loss 1.1961 | val acc 58.8400


 17%|███████▏                                  | 17/100 [03:14<15:06, 10.93s/it]

Epoch 16 | train loss 1.2036 | val loss 1.1840 | val acc 59.2200


 18%|███████▌                                  | 18/100 [03:25<14:51, 10.88s/it]

Epoch 17 | train loss 1.1919 | val loss 1.1761 | val acc 59.2400


 19%|███████▉                                  | 19/100 [03:36<14:37, 10.83s/it]

Epoch 18 | train loss 1.1842 | val loss 1.1956 | val acc 58.0300


 20%|████████▍                                 | 20/100 [03:47<14:28, 10.85s/it]

Epoch 19 | train loss 1.1840 | val loss 1.1870 | val acc 59.0500


 21%|████████▊                                 | 21/100 [03:58<14:16, 10.84s/it]

Epoch 20 | train loss 1.1563 | val loss 1.1764 | val acc 59.1100


 22%|█████████▏                                | 22/100 [04:08<13:57, 10.74s/it]

Epoch 21 | train loss 1.1241 | val loss 1.1715 | val acc 59.4900


 23%|█████████▋                                | 23/100 [04:19<13:43, 10.70s/it]

Epoch 22 | train loss 1.1407 | val loss 1.1714 | val acc 59.5100


 24%|██████████                                | 24/100 [04:29<13:27, 10.63s/it]

Epoch 23 | train loss 1.1092 | val loss 1.1866 | val acc 59.7700


 25%|██████████▌                               | 25/100 [04:40<13:11, 10.55s/it]

Epoch 24 | train loss 1.1405 | val loss 1.1567 | val acc 59.9000


 26%|██████████▉                               | 26/100 [04:50<13:04, 10.59s/it]

Epoch 25 | train loss 1.1196 | val loss 1.1369 | val acc 60.8600


 27%|███████████▎                              | 27/100 [05:01<13:02, 10.72s/it]

Epoch 26 | train loss 1.1132 | val loss 1.1402 | val acc 60.6900


 28%|███████████▊                              | 28/100 [05:12<12:59, 10.83s/it]

Epoch 27 | train loss 1.1023 | val loss 1.1455 | val acc 60.6900


 29%|████████████▏                             | 29/100 [05:24<13:03, 11.04s/it]

Epoch 28 | train loss 1.1138 | val loss 1.1466 | val acc 60.7900


 30%|████████████▌                             | 30/100 [05:35<12:53, 11.05s/it]

Epoch 29 | train loss 1.1009 | val loss 1.1361 | val acc 60.7600


 31%|█████████████                             | 31/100 [05:45<12:28, 10.84s/it]

Epoch 30 | train loss 1.0905 | val loss 1.1459 | val acc 60.5800


 32%|█████████████▍                            | 32/100 [05:56<12:06, 10.69s/it]

Epoch 31 | train loss 1.1032 | val loss 1.1353 | val acc 60.9300


 33%|█████████████▊                            | 33/100 [06:07<12:04, 10.81s/it]

Epoch 32 | train loss 1.1115 | val loss 1.1533 | val acc 60.7000


 33%|█████████████▊                            | 33/100 [06:11<12:34, 11.26s/it]


KeyboardInterrupt: 

In [None]:
Epoch 0 | train loss 1.9373 | val loss 1.6394 | val acc 41.9600
  2%|▊                                          | 2/100 [00:24<19:51, 12.16s/it]
Epoch 1 | train loss 1.7356 | val loss 1.5472 | val acc 45.1800
  3%|█▎                                         | 3/100 [00:36<19:46, 12.23s/it]
Epoch 2 | train loss 1.6277 | val loss 1.4846 | val acc 47.2700
  4%|█▋                                         | 4/100 [00:48<19:26, 12.15s/it]
Epoch 3 | train loss 1.5847 | val loss 1.4396 | val acc 49.5900
  5%|██▏                                        | 5/100 [01:00<19:23, 12.25s/it]
Epoch 4 | train loss 1.5357 | val loss 1.4488 | val acc 48.9100
  6%|██▌                                        | 6/100 [01:13<19:19, 12.34s/it]
Epoch 5 | train loss 1.4928 | val loss 1.3468 | val acc 52.9600
  7%|███                                        | 7/100 [01:26<19:26, 12.54s/it]
Epoch 6 | train loss 1.4379 | val loss 1.3484 | val acc 52.8400
  8%|███▍                                       | 8/100 [01:39<19:33, 12.75s/it]
Epoch 7 | train loss 1.3875 | val loss 1.3066 | val acc 54.5600
  9%|███▊                                       | 9/100 [01:51<19:06, 12.60s/it]
Epoch 8 | train loss 1.3699 | val loss 1.2766 | val acc 55.5900
 10%|████▏                                     | 10/100 [02:05<19:19, 12.88s/it]
Epoch 9 | train loss 1.3236 | val loss 1.2622 | val acc 55.9400
 11%|████▌                                     | 11/100 [02:19<19:49, 13.36s/it]
Epoch 10 | train loss 1.3133 | val loss 1.2524 | val acc 56.2600
 12%|█████                                     | 12/100 [02:34<20:07, 13.72s/it]
Epoch 11 | train loss 1.2969 | val loss 1.2430 | val acc 57.1900
 13%|█████▍                                    | 13/100 [02:48<20:07, 13.88s/it]
Epoch 12 | train loss 1.2637 | val loss 1.2338 | val acc 56.9000
 14%|█████▉                                    | 14/100 [03:05<21:02, 14.68s/it]
Epoch 13 | train loss 1.2559 | val loss 1.2223 | val acc 58.0000
 15%|██████▎                                   | 15/100 [03:21<21:37, 15.27s/it]
Epoch 14 | train loss 1.2275 | val loss 1.2109 | val acc 58.2600
 16%|██████▋                                   | 16/100 [03:37<21:25, 15.30s/it]
Epoch 15 | train loss 1.2327 | val loss 1.2032 | val acc 58.4900
 17%|███████▏                                  | 17/100 [03:53<21:35, 15.60s/it]
Epoch 16 | train loss 1.2229 | val loss 1.1895 | val acc 59.2200
 18%|███████▌                                  | 18/100 [04:09<21:31, 15.75s/it]
Epoch 17 | train loss 1.2117 | val loss 1.1838 | val acc 59.1000
 19%|███████▉                                  | 19/100 [04:24<20:52, 15.46s/it]
Epoch 18 | train loss 1.2045 | val loss 1.1971 | val acc 59.0200
 20%|████████▍                                 | 20/100 [04:39<20:19, 15.24s/it]
Epoch 19 | train loss 1.1952 | val loss 1.1982 | val acc 59.2300
 21%|████████▊                                 | 21/100 [04:55<20:20, 15.45s/it]
Epoch 20 | train loss 1.1688 | val loss 1.1887 | val acc 59.2200

In [None]:
plt.plot(S1)
plt.plot(S2)

In [None]:
W12 = W1@W2
plt.imshow(W12), print(W12.min(), W12.max()), 

In [None]:
# Validation
val_correct, val_loss, val_counter = 0, 0, 0
with torch.no_grad():
    for data, target in val_loader:
        val_counter += len(data)
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            _, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=True)
            _, loss_item = criteria(activations, signals, target.to(device), method="final")
        prediction = activations[-1].detach()
        _, predicted = torch.max(prediction, 1)
        val_correct += (predicted == target.to(device)).sum().item()
        val_loss += loss_item * len(data)

wandb.log({'val_loss': val_loss / val_counter, 
           'val_acc': val_correct / val_counter, 
           }, step=epoch)
s
print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

In [None]:
train_loss