## Define hyperparameters

In [1]:

# Dataset options
DATASET = 'CIFAR100'  # Options: 'CIFAR10' or 'CIFAR100'
# DATASET = 'CIFAR10'

# Number of classes options
NUM_CLASSES = 20     # Set the number of classes
# NUM_CLASSES = 10   # For example, if using CIFAR-10, set to 10

# Number of attention heads options
NUM_HEADS = 8        # Options: 8, 4, 2, etc.
# NUM_HEADS = 4
# NUM_HEADS = 2

In [2]:
def validate_hyperparameters(dataset_name, num_classes, num_heads):
    """
    Validates the hyperparameters for dataset, number of classes, and number of attention heads.

    Args:
        dataset_name (str): The name of the dataset ('CIFAR10' or 'CIFAR100').
        num_classes (int): The number of classes.
        num_heads (int): The number of attention heads.

    Raises:
        ValueError: If any hyperparameter is invalid.
    """
    valid_datasets = ['CIFAR10', 'CIFAR100']
    if dataset_name not in valid_datasets:
        raise ValueError(f"Invalid DATASET value: {dataset_name}. Choose from {valid_datasets}.")

    if dataset_name == 'CIFAR10' and num_classes != 10:
        raise ValueError(f"For {dataset_name}, NUM_CLASSES must be 10. Current value: {num_classes}.")
    elif dataset_name == 'CIFAR100' and num_classes not in [20, 100]:
        raise ValueError(f"For {dataset_name}, NUM_CLASSES must be 20 or 100. Current value: {num_classes}.")

    valid_heads = [8, 4, 2]
    if num_heads not in valid_heads:
        raise ValueError(f"Invalid NUM_HEADS value: {num_heads}. Choose from {valid_heads}.")

In [3]:

# Validate hyperparameters
validate_hyperparameters(DATASET, NUM_CLASSES, NUM_HEADS)


In [4]:
import torchvision
import torch
import torch.nn as nn



## Initial Setup

In [6]:
# -*- coding: utf-8 -*-
'''

Train CIFAR10 with PyTorch and Vision Transformers!
written by @kentaroy47, @arutema47
source : https://github.com/kentaroy47/vision-transformers-cifar10

'''

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
import csv
import time




### Helper functions

#### Saving loading

In [7]:
import torch
import os

def save_model_state(model, epoch, loss, accuracy, checkpoint_dir='checkpoints', log_file='training_log.txt'):
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Save model state
    model_checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pth')
    torch.save(model.state_dict(), model_checkpoint_path)
    print(f'Model state saved at epoch {epoch}')
    
    # Log accuracy and loss
    log_file_path = os.path.join(checkpoint_dir, log_file)
    
    with open(log_file_path, 'a') as f:
        f.write(f'Epoch {epoch}: Accuracy = {accuracy:.4f}, Loss = {loss:.4f}\n')
    
    print(f'Logged epoch {epoch} - Accuracy: {accuracy:.4f}, Loss: {loss:.4f}')

    
import torch
import os

def load_model_state(model, epoch = 90, checkpoint_dir='checkpoints'):
    model_checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pth')
    model.load_state_dict(torch.load(model_checkpoint_path))   
    print(f'Model state loaded from epoch {epoch}')
    return epoch



In [8]:
import os

directory = 'checkpoints'

if os.path.isdir(directory):
    print("Directory exists")
    items = os.listdir(directory)
    for item in items:
        print(item)
else:
    print("Directory does not exist")
    
#     if os.path.isdir(directory):
    


Directory does not exist


#### Remapping labels function

In [9]:
def remap_labels(labels, num_classes_old, num_classes_new):
    """
    Adjusts the labels from an old class structure to a new one.

    Args:
        labels (torch.Tensor or list): Original labels to be adjusted.
        num_classes_old (int): The number of classes in the original dataset.
        num_classes_new (int): The number of classes in the new dataset.

    Returns:
        torch.Tensor or list: The labels adjusted to the new class structure.
    """
    # Check that the number of old classes is divisible by the number of new classes
    assert num_classes_old % num_classes_new == 0, "The number of old classes must be divisible by the number of new classes."

    # Compute the factor to convert old labels to new labels
    factor = num_classes_old // num_classes_new

    # Remap each label
    if isinstance(labels, torch.Tensor):
        # If labels are a tensor, apply the remapping to each label and return a tensor
        remapped_labels = torch.tensor([label.item() // factor for label in labels])
    else:
        # If labels are a list, apply the remapping to each label and return a list
        remapped_labels = [label // factor for label in labels]
    
    return remapped_labels


In [11]:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, dataset, num_classes_old, num_classes_new):
        self.dataset = dataset
        self.num_classes_old = num_classes_old
        self.num_classes_new = num_classes_new

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        remapped_label = remap_labels(torch.tensor([label]), self.num_classes_old, self.num_classes_new).item()
        return image, remapped_label

In [12]:
run_number = 1
base_dir = "results/runs"

# Define the new run directory

os.makedirs("results", exist_ok=True)
os.makedirs("results/runs", exist_ok=True)

In [13]:
torch.manual_seed(42)

<torch._C.Generator at 0x7fbd488e84f0>

In [14]:
# # setup for a read only personal access token
# # note : token expires 19 aug 2025
# token = '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'
# token_user = 'Asterisk07'
# repo_host = 'Asterisk07'
# repo_name = 'BTP-Transformer-explainability'

# url = f'https://{token_user}:{token}@github.com/{repo_host}/{repo_name}/'
# !git clone {url}

# !mv {repo_name}/models .
# !rm -rf BTP-Transformer-explainability # delete a file

In [15]:
!ls

results


In [16]:
# !rm -rf models

In [17]:
# !npm install -g github-files-fetcher

In [18]:
# !fetcher --url=https://github.com/kentaroy47/vision-transformers-cifar10/tree/main/models
# !fetcher --url=https://https://github.com/Asterisk07/BTP-Transformer-explainability/main/models


In [19]:
2

2

In [20]:

import os

# Check if 'utils.py' exists in the current directory
if os.path.exists('utils.py'):
    print("utils.py exists in the current directory.")
else:
    print("utils.py does not exist in the current directory.")
    !wget https://raw.githubusercontent.com/kentaroy47/vision-transformers-cifar10/main/utils.py
    print("utils.py fetched")



utils.py does not exist in the current directory.
--2024-09-07 09:35:27--  https://raw.githubusercontent.com/kentaroy47/vision-transformers-cifar10/main/utils.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3501 (3.4K) [text/plain]
Saving to: 'utils.py'


2024-09-07 09:35:27 (38.7 MB/s) - 'utils.py' saved [3501/3501]

utils.py fetched


In [21]:
from utils import progress_bar

stty: 'standard input': Inappropriate ioctl for device


In [22]:
progress_bar

<function utils.progress_bar(current, total, msg=None)>

In [23]:


# from randomaug import RandAugment
from torchvision.transforms import RandAugment



In [24]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m718.5 kB/s[0m eta [36m0:00:00[0m[36m0:00:01[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [25]:
# from models import *
# from models.vit import ViT
# from models.convmixer import ConvMixer

In [26]:

import json

In [27]:

qkv_titles = ['q','k','v']

In [28]:
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
# VIT.py
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np
# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x,save_flag=False, run_dir = None,img_idx = None):
        out =  self.net(x)
        if(save_flag==True):
                file_path = os.path.join(run_dir, 'ff_out.npy')
                # np.save(file_path, out)
                np.save(file_path, out[img_idx].detach().cpu().numpy())
        return out

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()

        inner_dim = dim_head *  heads
        # print("attention : dim = ", dim, "| inner_dim = ",inner_dim,"| dim_head = ", dim_head, "| heads = ",heads  )
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x,save_flag=False, run_dir = None,img_idx = None):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)



        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        if(save_flag==True):

                # Convert each tensor in `qkv` to a numpy array and save it
#                 qkv=attention.to_qkv
                for i, tensor in enumerate((q,k,v)):
                    np_array = tensor[img_idx].detach().cpu().numpy()  # Convert to numpy
                    # np.save(f'qkv_{i}.npy', np_array)  # Save each as a .npy file
                    file_path = os.path.join(run_dir, f'{qkv_titles[i]}.npy')
                    np.save(file_path, np_array)
                file_path = os.path.join(run_dir, 'att_out')
                np.save(file_path, out[img_idx].detach().cpu().numpy())
                file_path = os.path.join(run_dir, 'att_score')
                np.save(file_path,attn[img_idx].detach().cpu().numpy())
        out = rearrange(out, 'b h n d -> b n (h d)')
        # return self.to_out(out),q,k,v
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        # print("transformer : dim = ", dim, "| dim_head = ", dim_head, "| heads = ",heads  )

        self.layers = nn.ModuleList([])
#         self.saved_values = {'logits': [], 'queries': [], 'keys': [], 'values': []}  # To store the values
        # self.saved_values = list()  # To store th
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x,save_flag=False, run_dir = None, img_idx = None):
        for i, (attn, ff) in enumerate(self.layers):
            # Unpack the output from the Attention layer
            #
            # print("passed trans direcetory ", run_dir, " and saving ",save_flag)
            if save_flag:
              layer_dir = os.path.join(run_dir,  f"layer {i:02}")
              os.makedirs(layer_dir, exist_ok=True)
              # print("passed trans layer direcetory ", layer_dir)
            else:
              layer_dir = None
            attn_out = attn(x,save_flag=save_flag, run_dir = layer_dir, img_idx = img_idx)



            # Save the query, key, value, and logits (output) for this layer
            # self.saved_values.append(q.cpu().detach().numpy())
            # self.saved_values.append(k.cpu().detach().numpy())
            # self.saved_values.append(v.cpu().detach().numpy())

            # Combine the attention output with the original x
            x = attn_out + x
            # self.saved_values.append(x.cpu().detach().numpy())  # Save logits
            # print("i : ",i)
            # Apply the feedforward network
#             x = ff(x) + x

            x = ff(x,save_flag=save_flag, run_dir = layer_dir, img_idx = img_idx) + x

        return x


class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # print("vit : dim = ", dim, "| dim_head = ", dim_head, "| heads = ",heads , " | mlp = ",mlp_dim )

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, save_flag=False, run_dir = None,img_idx = None):
        # if (save_flag):
          # print("\n\treached here 3")
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x,save_flag, run_dir, img_idx)
#         if(save_flag==True):
#                 # Convert each tensor in `qkv` to a numpy array and save it
#                 qkv=attention.to_qkv
#                 for i, tensor in enumerate(qkv):
#                     np_array = tensor.detach().cpu().numpy()  # Convert to numpy
#                     np.save(f'qkv_{i}.npy', np_array)  # Save each as a .npy file


        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [29]:
import argparse
import sys

# Define your arguments here
def parse_args():
    # parsers
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') # resnets.. 1e-3, Vit..1e-4
    parser.add_argument('--opt', default="adam")
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--noaug', action='store_false', help='disable use randomaug')
    parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
    parser.add_argument('--nowandb', action='store_true', help='disable wandb')
    parser.add_argument('--mixup', action='store_true', help='add mixup augumentations')
    parser.add_argument('--net', default='vit')
    parser.add_argument('--dp', action='store_true', help='use data parallel')
    parser.add_argument('--bs', default='512')
    parser.add_argument('--size', default="32")
    parser.add_argument('--n_epochs', type=int, default='200')
    parser.add_argument('--patch', default='4', type=int, help="patch for ViT")
    parser.add_argument('--dimhead', default="512", type=int)
    parser.add_argument('--convkernel', default='8', type=int, help="parameter for convmixer")

    return parser.parse_args()




In [30]:
command = 'python train_cifar10.py --n_epochs 500 --lr 0.0005'
command.split()[1:]

['train_cifar10.py', '--n_epochs', '500', '--lr', '0.0005']

In [31]:
# Simulate command-line arguments
# sys.argv = ['your_script.py', '--lr', '0.2', '--opt', 'adam', '--net', 'vit', '--bs', '64','--dimhead','256']
sys.argv = command.split()[1:]

args = parse_args()



In [32]:
# !pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118 --upgrade --force-reinstall

In [33]:
# (2.0.1+cu117)
# Requirement already satisfied: torchvision in /opt/conda/lib/python3.10/site-packages (0.15.2+cu117)

In [34]:
# !pip show torchvision


In [35]:
2

2

In [36]:
# !pip show torch
# #

In [37]:
import torchvision
torchvision.__version__

'0.19.0'

In [38]:
import torch
torch.__version__

'2.4.0'

In [39]:
!pip install wandb



In [40]:

# take in args
usewandb = ~args.nowandb
if usewandb:
    import wandb
    watermark = "{}_lr{}".format(args.net, args.lr)
    wandb.init(project="cifar10-challange",
            name=watermark)
    wandb.config.update(args)

bs = int(args.bs)
imsize = int(args.size)

use_amp = not args.noamp
aug = args.noaug

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
if args.net=="vit_timm":
    size = 384
else:
    size = imsize

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Add RandAugment with N, M(hyperparameter)
if aug:
    N = 2; M = 14;
    transform_train.transforms.insert(0, RandAugment(N, M))

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


==> Preparing data..


In [41]:
# use only this token :
# f439c9e9cdf4ff7e3d47e80d4588628783d8bafe #aster

In [42]:
NUM_WORKERS = 4

# Prepare dataset

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

# testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)


# Prepare dataset based on hyperparameter
if DATASET == 'CIFAR10':
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
elif DATASET == 'CIFAR100':
    if NUM_CLASSES%20 != 0 :
        raise ValueError("Invalid value of NUM_CLASSES specified. Choose 20 or 100")
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
else:
    raise ValueError("Invalid dataset specified. Choose 'CIFAR10' or 'CIFAR100'.")

    
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=NUM_WORKERS)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:02<00:00, 63020183.91it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


### Remapping labels if needed

In [43]:
if NUM_CLASSES == 20:
    # Create custom dataset class to remap labels


    # Create custom datasets with remapped labels
    trainset = CustomDataset(trainset, num_classes_old=100, num_classes_new=NUM_CLASSES)
    testset = CustomDataset(testset, num_classes_old=100, num_classes_new=NUM_CLASSES)


In [44]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=NUM_WORKERS)

In [45]:
# For Multi-GPU
if 'cuda' in device:
    print(device)
    if args.dp:
        print("using data parallel")
        net = torch.nn.DataParallel(net) # make parallel
        cudnn.benchmark = True


cuda


In [46]:
!rm -rf results

In [47]:

def get_vit():
    return ViT(
    image_size = size,
    patch_size = args.patch,
    # num_classes = 10,
    num_classes = NUM_CLASSES,
    dim = int(args.dimhead),
    depth = 6,
    # heads = 8,
    heads = NUM_HEADS,
    # mlp_dim = 512,
    mlp_dim = 256,
    dropout = 0.1,
    emb_dropout = 0.1,
    )

In [48]:


classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model factory..
print('==> Building model..')
# net = VGG('VGG19')
if args.net=="vit":
    # ViT for cifar10
    net = get_vit()



if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/{}-ckpt.t7'.format(args.net))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

==> Building model..


In [49]:
from tqdm import tqdm

In [50]:

len(trainloader)

98

In [51]:

# trainloader[0]

In [52]:
MAX_EPOCHS = 90

In [53]:
import numpy as np

# Loss is CE
criterion = nn.CrossEntropyLoss()

torch.manual_seed(42)
net = get_vit()

if args.opt == "adam":
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
elif args.opt == "sgd":
    optimizer = optim.SGD(net.parameters(), lr=args.lr)

# use cosine scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)

##### Training
scaler = torch.amp.GradScaler('cuda',enabled=use_amp)
def train(epoch,save_flag, run_dir = None, img_idx = None):
    
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # img_factor = len(trainloader) // img_save_count
    # run_dir = os.path.join(run_dir,  {epoch:02}")

    # data_save=list()
    # main_list=list()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Train with amp
        with torch.amp.autocast('cuda',enabled=use_amp):
            # if(save_flag==True and batch_idx%img_factor==0):

            if(save_flag==True and batch_idx==0):
                batch_dir = os.path.join(run_dir, f'batch {batch_idx}')
                os.makedirs(batch_dir, exist_ok=True)
                # np.save(file_path, np_array)
                # print("\n\tpassed ",batch_dir, type(batch_dir))

                outputs = net(inputs, True, batch_dir, img_idx)
                # outputs = net(inputs, False, 12)
                #here can pass in net(inputs,image_saveflag=1) so it will save the image to disk by making changes in model.
            else:
                outputs = net(inputs)
            loss = criterion(outputs, targets)



        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # data_save.append(net.transformer.saved_values)

#         progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#             % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
#     return train_loss/(batch_idx+1),net.transformer.saved_values
    return train_loss/(batch_idx+1)
##### Validation
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

#             progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#                 % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
#     if acc > best_acc:
#         print('Saving..')
#         state = {"model": net.state_dict(),
#               "optimizer": optimizer.state_dict(),
#               "scaler": scaler.state_dict()}
#         if not os.path.isdir('checkpoint'):
#             os.mkdir('checkpoint')
#         torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch))
#         best_acc = acc

#     os.makedirs("log", exist_ok=True)

    os.makedirs("results", exist_ok=True)
    os.makedirs("results/log", exist_ok=True)
    content = f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
    print(content)
#     with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender:
#         appender.write(content + "\n")
    return test_loss, acc

list_loss = []
list_acc = []

if usewandb:
    wandb.watch(net)

    

# save_epochs-=1
batch_size = int(args.bs)
# max_epochs = args.n_epochs




if device == 'cuda':
  net.cuda()
main_list=list()
data_save=list()
n_param=5

run_dir = os.path.join(base_dir, f"run {run_number:02}")
os.makedirs(run_dir, exist_ok=True)
print("Run number ",run_number)






import shutil
from IPython.display import FileLink

# Specify the directory you want to compress
directory_name = run_dir
zip_filename = f'{run_dir}.zip'




run_number += 1

max_epochs = MAX_EPOCHS

# take_epoch_factor = 
img_save_count = 50 #IMP

img_idx = torch.randperm(batch_size)[:img_save_count]
img_idx= img_idx.sort()[0]
print("chosen images are of batch 0 and numbers : ",[x.item() for x in list(img_idx)])
file_path = os.path.join(run_dir, 'img_idx.npy')
np.save(file_path, img_idx.detach().cpu().numpy())



# save_epochs =  #IMP
# epoch_factor = max_epochs  // save_epochs #IMP

epoch_factor = 10 #IMP


print(f"Saving results every {epoch_factor} epochs ")




print("Training started")
for i in tqdm(range(start_epoch, max_epochs), desc="Training"):
    epoch = i+1
    start = time.time()
    
#     if(epoch%epoch_factor==0 or epoch == 1 or epoch == max_epochs):
    if False:
      # Define the new run directory
        
        epoch_dir = os.path.join(run_dir, f"epoch {epoch:02}")
        # print("\n\tpassed into trainloss",run_dir)
        trainloss = train(epoch,True, run_dir = epoch_dir, img_idx = img_idx)
        print("saved epoch")
        # Compress the directory into a zip file, overwriting if it already exists
        shutil.make_archive(zip_filename.replace('.zip', ''), 'zip', directory_name)

#         print(f"Directory '{directory_name}' has been zipped as '{zip_filename}'.")
        print("Click here to download run  : ")
        display(FileLink(zip_filename))
    
    else:
        trainloss = train(epoch,False)



#     if(epoch%n_param!=0 or epoch==0):
#         data_save.append(saved_data)
#     else:
#         data_save.append(saved_data)
#         main_list.append(data_save)
#         data_save=list()
#     val_loss, acc = test(epoch)

    scheduler.step() # step cosine scheduling
    if (epoch%epoch_factor==0 or epoch == 1 or epoch == max_epochs and MODEL_SAVE_FLAG):
        
        
        val_loss, acc = test(epoch)
        save_model_state(net, epoch , val_loss, acc)
        

#     list_loss.append(val_loss)
#     list_acc.append(acc)

    # Log training..
#     if usewandb:
#         wandb.log({'epoch': epoch, 'train_loss': trainloss, 'val_loss': val_loss, "val_acc": acc, "lr": optimizer.param_groups[0]["lr"],
#         "epoch_time": time.time()-start})

#     # Write out csv..
#     with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f:
#         writer = csv.writer(f, lineterminator='\n')
#         writer.writerow(list_loss)
#         writer.writerow(list_acc)
# #     print(list_loss)
    print()
    
# writeout wandb
if usewandb:
    wandb.save("wandb_{}.h5".format(args.net))

import shutil
from IPython.display import FileLink
zip_filename = 'checkpoints.zip'
directory_name = 'checkpoints'
shutil.make_archive(zip_filename.replace('.zip', ''), 'zip', directory_name)

#         print(f"Directory '{directory_name}' has been zipped as '{zip_filename}'.")
print("Click here to download checkpoints  : ")
display(FileLink(zip_filename))

Run number  1
chosen images are of batch 0 and numbers :  [7, 31, 33, 38, 45, 57, 58, 59, 60, 98, 118, 126, 131, 135, 139, 141, 142, 143, 147, 155, 162, 184, 209, 219, 233, 245, 252, 280, 286, 296, 310, 327, 349, 351, 357, 365, 368, 399, 411, 422, 424, 425, 431, 442, 452, 457, 463, 481, 482, 502]
Saving results every 10 epochs 
Training started


Training:   1%|          | 1/90 [00:32<47:32, 32.05s/it]

Epoch 1, lr: 0.0005000, val loss: 56.42817, acc: 13.21000
Model state saved at epoch 1
Logged epoch 1 - Accuracy: 13.2100, Loss: 56.4282



Training:   2%|▏         | 2/90 [00:59<43:05, 29.38s/it]




Training:   2%|▏         | 2/90 [01:14<54:33, 37.20s/it]


KeyboardInterrupt: 

In [None]:
import os

# Specify the path to your file
file_path = 'checkpoints.zip'

# Get the size of the file in bytes
file_size = os.path.getsize(file_path)

print(f"The size of the file is : {(file_size/(2**20)):.0f} MB")


In [None]:
raise ZeroDivisionError

In [None]:
!cd data
!ls


In [None]:
!ls

In [None]:
# !rm -rf results
# !rm -rf log


In [None]:
# 

In [None]:
# img_idx.detach().numpy()

In [None]:
# import shutil
# from IPython.display import FileLink

# # Specify the directory you want to compress
# directory_name = 'log'
# zip_filename = 'log.zip'

# # Compress the directory into a zip file, overwriting if it already exists
# shutil.make_archive(zip_filename.replace('.zip', ''), 'zip', directory_name)

# # Optionally generate and display a download link
# print(f"Directory '{directory_name}' has been zipped as '{zip_filename}'.")
# FileLink(zip_filename)


In [None]:

net

In [None]:

/content/results/runs/run 04/epoch 00/batch 0/layer 01

In [None]:
!ls

In [None]:
cd 

In [None]:

import os

# Define the directory path you want to check
directory_path = r'results/runs/run 03/epoch 00/batch 0/layer 01/'

# Check if the directory exists
if os.path.isdir(directory_path):
    print(f"The directory '{directory_path}' exists.")
else:
    print(f"The directory '{directory_path}' does not exist.")


In [None]:
file_path = r'results/runs/run 03/epoch 00/batch 0/layer 01/01_attention_out.npy'

# Load the NumPy array from the file
data = np.load(file_path)

In [None]:
data.shape

In [None]:

data.shape
# shape : batch x head x X x Y