In [None]:
# Mount Google Drive; the cells below copy the pretrained checkpoint, the
# zipped frame datasets and the logits pickle from /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Clone the MARS code base, copy the pretrained MARS UCF101 16-frame checkpoint
# from Drive to /content (the default --pretrain_path in parse_opts), and copy the
# repo contents into /content -- presumably so that the `dataset` and `models`
# packages imported in the next cell resolve from the working directory.
!git clone https://github.com/thefurorjuror/MARS.git
!cp /content/drive/MyDrive/MARS_UCF101_16f.pth /content/MARS_UCF101_16f.pth
!cp -a /content/MARS/. /content/

In [None]:
from __future__ import division
!pip install -qqq wandb
import wandb
import os
import re
import numpy as np
from sklearn.metrics import confusion_matrix
from dataset.dataset import *
from torch.utils.data import Dataset, DataLoader
import getpass
import socket
import numpy as np
from dataset.preprocess_data import *
from PIL import Image, ImageFilter
import argparse
import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from models.model import generate_model
from torch.autograd import Variable
import time
import sys
import pdb
import argparse
import shutil
from tqdm import tqdm
import pickle
from tqdm.notebook import tqdm
from pathlib import Path
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset
import csv

In [None]:
def init_wandb(model, args=None, pytorch=True) -> None:
    """
    Initialize a run on Weights & Biases.

    Args:
        model: model handed to ``wandb.watch`` (only used when ``pytorch`` is True).
        args (dict, optional): W&B configuration with the keys
            wandb_api_key : API key. If missing or empty, the WANDB_API_KEY
                environment variable is used; if that is unset too, ``wandb.login``
                receives ``key=None`` and falls back to its own login flow.
            wandb_name : unique, descriptive name for the run.
            project : name of the W&B project.
            Sample: {'wandb_api_key': '', 'wandb_name': 'test', 'project': 'test_project'}
            Defaults to None, which is treated like an empty dict.
        pytorch (bool): whether ``model`` is a PyTorch model; ``wandb.watch`` is
            only called in that case.
    """
    # ``args`` is documented as optional; indexing ``None`` used to raise TypeError.
    config = dict(args) if args else {}
    api_key = config.get('wandb_api_key') or os.environ.get('WANDB_API_KEY') or None

    wandb.login(key=api_key)
    wandb.init(
        name=config.get('wandb_name'),
        project=config.get('project'),
        resume=True,
        dir="./",
    )
    if pytorch:
        wandb.watch(model, log="all")


In [None]:
def parse_opts(argv=None):
    """Build the option namespace used by model creation and training.

    Args:
        argv (list[str], optional): command-line style tokens, e.g.
            ``['--batch_size', '8']``. Defaults to None, which parses an empty
            list, so only the defaults below apply and ``sys.argv`` (the
            notebook kernel's own arguments) is never read.

    Returns:
        argparse.Namespace with one attribute per option below.
    """

    def _str2bool(value):
        # ``type=bool`` turns every non-empty string -- including "False" and
        # "0" -- into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Datasets
    add('--frame_dir', default='dataset/HMDB51/', type=str, help='path of jpg files')
    add('--annotation_path', default='dataset/HMDB51_labels', type=str, help='label paths')
    add('--dataset', default='HMDB51', type=str, help='(HMDB51, UCF101, Kinectics)')
    # argparse only applies ``type`` to *string* defaults, so opt.split stays the
    # int 1 unless --split is passed explicitly (then it is a str).
    add('--split', default=1, type=str, help='(for HMDB51 and UCF101)')
    add('--modality', default='RGB', type=str, help='(RGB, Flow)')
    add('--input_channels', default=3, type=int, help='(3, 2)')
    add('--n_classes', default=101, type=int,
        help='Number of classes (activitynet: 200, kinetics: 400, ucf101: 101, hmdb51: 51)')
    add('--n_finetune_classes', default=600, type=int,
        help='Number of classes for fine-tuning. n_classes is set to the number when pretraining.')
    add('--only_RGB', action='store_true', help='Extracted only RGB frames')
    parser.set_defaults(only_RGB=False)

    # Model parameters
    add('--output_layers', action='append', help='layer to output on forward pass')
    parser.set_defaults(output_layers=[])
    add('--model', default='resnext', type=str, help='Model base architecture')
    add('--model_depth', default=101, type=int, help='Number of layers in model')
    add('--resnet_shortcut', default='B', type=str, help='Shortcut type of resnet (A | B)')
    add('--resnext_cardinality', default=32, type=int, help='ResNeXt cardinality')
    add('--ft_begin_index', default=4, type=int, help='Begin block index of fine-tuning')
    add('--sample_size', default=224, type=int, help='Height and width of inputs')
    add('--sample_duration', default=16, type=int, help='Temporal duration of inputs')
    # NOTE(review): a store_true flag whose default is True can never switch
    # training off from the command line.
    add('--training', action='store_true', help='training/testing')
    parser.set_defaults(training=True)
    add('--freeze_BN', action='store_true', help='freeze_BN/testing')
    parser.set_defaults(freeze_BN=False)
    add('--batch_size', default=1, type=int, help='Batch Size')
    add('--n_workers', default=4, type=int, help='Number of workers for dataloader')

    # Optimizer parameters
    add('--learning_rate', default=0.1, type=float,
        help='Initial learning rate (divided by 10 while training by lr scheduler)')
    add('--momentum', default=0.9, type=float, help='Momentum')
    add('--dampening', default=0.9, type=float, help='dampening of SGD')
    add('--weight_decay', default=1e-3, type=float, help='Weight Decay')
    add('--nesterov', action='store_true', help='Nesterov momentum')
    parser.set_defaults(nesterov=False)
    add('--optimizer', default='sgd', type=str, help='Currently only support SGD')
    add('--lr_patience', default=10, type=int,
        help='Patience of LR scheduler. See documentation of ReduceLROnPlateau.')
    add('--MARS_alpha', default=50, type=float, help='Weight of Flow augemented MSE loss')
    add('--n_epochs', default=400, type=int, help='Number of total epochs to run')
    add('--begin_epoch', default=1, type=int,
        help='Training begins at this epoch. Previous trained model indicated by resume_path is loaded.')

    # Logging / checkpoint options
    add('--result_path', default='', type=str, help='result_path')
    add('--MARS', action='store_true', help='test MARS')
    parser.set_defaults(MARS=False)
    add('--pretrain_path', default='/content/MARS_UCF101_16f.pth', type=str,
        help='Pretrained model (.pth)')
    add('--MARS_pretrain_path', default='', type=str, help='Pretrained model (.pth)')
    add('--MARS_resume_path', default='', type=str, help='MARS resume model (.pth)')
    for suffix in ('1', '2', '3'):
        add('--resume_path' + suffix, default='', type=str,
            help='Save data (.pth) of previous training')
    add('--log', default=1, type=int, help='Log training and validation')
    add('--checkpoint', default=2, type=int, help='Trained model is saved at every this epochs.')

    add('--manual_seed', default=1, type=int, help='Manually set random seed')
    add('--random_seed', default=1, type=_str2bool,
        help='Manually set random seed of sampling validation clip')

    return parser.parse_args(args=[] if argv is None else list(argv))

In [None]:
class VideoLogitDataset(Dataset):
    """In-memory dataset pairing frame directories with rows of a logits tensor.

    Every sub-directory of ``video_dir_path`` is one video. Its frame images are
    read in natural filename order, resized to 224x224, scaled to [0, 1] as
    float32 and kept in RAM. Videos are taken in sorted directory-name order,
    and video ``i`` is paired with ``logits[i]``, so the logits must be stored in
    that same order.

    Args:
        video_dir_path: directory holding one sub-directory of frames per video.
        logits_file: pickle containing a torch tensor (``.numpy()`` is called on
            it), presumably one row of logits per video.
        transform: optional callable applied to each clip returned by ``__getitem__``.
    """

    FRAME_SIZE = (224, 224)

    def __init__(self, video_dir_path, logits_file, transform=None):
        self.video_dir_path = video_dir_path
        self.transform = transform
        self.instances = []  # per-video frame stacks; stacked into one tensor below

        # NOTE(review): pickle can execute arbitrary code -- only load trusted files.
        with open(logits_file, 'rb') as fh:
            logits = pickle.load(fh)
        self.logits = torch.tensor(logits.numpy())

        self.videos = sorted(str(p.name) for p in Path(self.video_dir_path).iterdir() if p.is_dir())
        self.get_frames()

        self.instances = torch.stack(self.instances)
        self.num_instances = len(self.instances)

    @staticmethod
    def _natural_key(name):
        # Digit runs (odd positions of the split) compare as ints, so "10.jpg"
        # sorts after "2.jpg"; plain sorting would scramble unpadded frame numbers.
        return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', name))]

    @classmethod
    def _load_frame(cls, path):
        # One frame as a float32 tensor in [0, 1]; the image file is closed afterwards.
        with Image.open(path) as img:
            resized = img.resize(cls.FRAME_SIZE)
        return torch.tensor(np.array(resized, dtype=np.float32) / 255.0)

    def get_frames(self):
        """Load every video's frames (in temporal order) into ``self.instances``."""
        for video in tqdm(self.videos, position=0, leave=True):
            video_dir = os.path.join(self.video_dir_path, video)
            # os.listdir order is arbitrary; sort so frames keep their temporal order.
            frame_names = sorted(os.listdir(video_dir), key=self._natural_key)
            frames = [self._load_frame(os.path.join(video_dir, name)) for name in frame_names]
            self.instances.append(torch.stack(frames))

    def __getitem__(self, idx):
        # Stored clips are (T, H, W, C); swapping axes 0 and 3 yields (C, H, W, T).
        vid = self.instances[idx].swapaxes(0, 3)
        if self.transform:
            vid = self.transform(vid)
        return vid, self.logits[idx]

    def __len__(self):
        return len(self.instances)


class ValDataset(Dataset):

    def __init__(self, video_dir_path, classes_file, labels_file, num_classes, transform=None):

        self.video_dir_path = video_dir_path
        self.classes_file = classes_file
        self.labels_file = labels_file
        self.transform = transform

        self.videos = sorted([str(x.name) for x in Path(self.video_dir_path).iterdir() if x.is_dir()])
        self.num_instances = len(self.videos)
        self.num_classes = num_classes

        self.label_dict = pd.read_csv(self.labels_file, header=None, index_col=1, squeeze=False).to_dict()
        self.label_dict = self.label_dict[0]

        self.classes_dict = pd.read_csv(self.classes_file, header=None, index_col=1, squeeze=False).to_dict()
        self.classes_dict = self.classes_dict[0]

        self.new_classes_dict = {}
        for index, (id, label) in enumerate(self.classes_dict.items()):
            if index == 0:
                continue
            self.new_classes_dict[id] = self.label_dict[label]
        print(self.new_classes_dict)
        print(len(self.new_classes_dict))

    def get_id(self, video_name):
        k = 0
        rev = video_name[::-1]
        for x in range(len(video_name)):
            if rev[x] == '_':
                k = k + 1
            if k >= 2:
                k = x
                break

        id = video_name[0:len(video_name) - k - 1]

        return id

    def get_label(self, idx):
        video_name = self.videos[idx]
        video_id = self.get_id(video_name)
        label = self.new_classes_dict[video_id]
        one_hot = F.one_hot(torch.tensor(int(label)), self.num_classes)
        return one_hot

    def get_frames(self, video_path):
        images = sorted(os.listdir(video_path))
        image_frames = []

        for image_name in images:
            image = Image.open(os.path.join(video_path, image_name))
            newsize = (224,224)
            image = image.resize(newsize)
            image = np.array(image, dtype=np.float32)
            image = image / 255.0
            image_frames.append(torch.tensor(image))

        return torch.stack(image_frames)

    def __getitem__(self, idx):
        video_path = os.path.join(self.video_dir_path, self.videos[idx])
        vid = self.get_frames(video_path)
        if self.transform:
            vid = vid.permute(0, 3, 1, 2)
            vid = self.transform(vid)
            vid = vid.permute(0, 2, 3, 1)
        vid = vid.swapaxes(0, 3)  # <C3D Transform>
        return vid, self.get_label(idx)

    def __len__(self):
        return self.num_instances

class ValDataset_var_frames(Dataset):
    """Validation videos with one-hot labels, each clip forced to ``frames_needed`` frames.

    Labels come from two header-less CSVs, each read as ``{column 1: column 0}``:
    ``labels_file`` maps label value -> class index, ``classes_file`` maps
    video id -> label value (its first row is skipped, presumably a header --
    TODO confirm). A video directory name maps to its id by dropping its last
    two ``_``-separated fields.

    Args:
        video_dir_path: directory holding one sub-directory of frames per video.
        classes_file: CSV mapping video ids to labels.
        labels_file: CSV mapping labels to class indices.
        num_classes: length of the one-hot label vectors.
        frames_needed: number of frames every returned clip has.
        transform: optional callable applied to (T, C, H, W) frame stacks.
    """

    FRAME_SIZE = (224, 224)

    def __init__(self, video_dir_path, classes_file, labels_file, num_classes, frames_needed, transform=None):
        self.video_dir_path = video_dir_path
        self.classes_file = classes_file
        self.labels_file = labels_file
        self.transform = transform
        self.frames_needed = frames_needed
        self.num_classes = num_classes

        self.videos = sorted(str(p.name) for p in Path(self.video_dir_path).iterdir() if p.is_dir())
        self.num_instances = len(self.videos)

        self.label_dict = self._read_mapping(self.labels_file)
        self.classes_dict = self._read_mapping(self.classes_file)

        # Resolve every video id straight to its class index, skipping row 0.
        self.new_classes_dict = {
            video_id: self.label_dict[label]
            for index, (video_id, label) in enumerate(self.classes_dict.items())
            if index != 0
        }
        print(len(self.new_classes_dict))

    @staticmethod
    def _read_mapping(csv_path):
        # Header-less CSV -> {column 1: column 0}. ``squeeze=False`` (the default)
        # is no longer passed: it was deprecated in pandas 1.4 and removed in 2.0.
        return pd.read_csv(csv_path, header=None, index_col=1).to_dict()[0]

    @staticmethod
    def _natural_key(name):
        # Digit runs (odd positions of the split) compare as ints, so "10.jpg"
        # sorts after "2.jpg".
        return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', name))]

    def get_id(self, video_name):
        """Strip the last two ``_``-separated fields: ``abc_000010_000020`` -> ``abc``."""
        underscores = video_name.count('_')
        if underscores >= 2:
            return video_name.rsplit('_', 2)[0]
        # Legacy fallback for malformed names, kept for compatibility: drop the
        # last (1 + number of underscores) characters.
        return video_name[:len(video_name) - underscores - 1]

    def get_label(self, idx):
        """One-hot (``num_classes``) label of video ``idx``."""
        video_id = self.get_id(self.videos[idx])
        label = self.new_classes_dict[video_id]
        return F.one_hot(torch.tensor(int(label)), self.num_classes)

    def _load_frame(self, path):
        # One frame as a float32 tensor in [0, 1]; the image file is closed afterwards.
        with Image.open(path) as img:
            resized = img.resize(self.FRAME_SIZE)
        return torch.tensor(np.array(resized, dtype=np.float32) / 255.0)

    def get_frames(self, video_path):
        """Return exactly ``frames_needed`` stacked frames for one video directory.

        * no frames: a placeholder clip of ones;
        * at least ``frames_needed`` frames: the first ``frames_needed`` in order;
        * ``n`` < ``frames_needed`` frames: each frame repeated
          ``frames_needed // n`` times, the first frame also taking the
          ``frames_needed % n`` leftover copies.
        """
        images = sorted(os.listdir(video_path), key=self._natural_key)
        frames_avlb = len(images)

        if frames_avlb == 0:
            # Was a 2-D (224, 224) float64 array, which broke the 4-D permutes in
            # __getitem__; build an (H, W, 3) float32 frame instead (assumes the
            # real frames are 3-channel images).
            width, height = self.FRAME_SIZE
            blank = torch.ones((height, width, 3), dtype=torch.float32)
            return torch.stack([blank] * self.frames_needed)

        if frames_avlb >= self.frames_needed:
            chosen = images[:self.frames_needed]
            return torch.stack([self._load_frame(os.path.join(video_path, name)) for name in chosen])

        reps, leftover = divmod(self.frames_needed, frames_avlb)
        image_frames = []
        for position, image_name in enumerate(images):
            frame = self._load_frame(os.path.join(video_path, image_name))
            # The old counter was incremented before its ``== 0`` test, so the
            # leftover copies were never added and clips came out short.
            copies = reps + leftover if position == 0 else reps
            image_frames.extend([frame] * copies)
        return torch.stack(image_frames)

    def __getitem__(self, idx):
        video_path = os.path.join(self.video_dir_path, self.videos[idx])
        vid = self.get_frames(video_path)
        if self.transform:
            # The transform is applied to (T, C, H, W) frames, then converted back.
            vid = self.transform(vid.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        vid = vid.swapaxes(0, 3)  # <C3D Transform>
        return vid, self.get_label(idx)

    def __len__(self):
        return self.num_instances

def _natural_sort_key(name):
    """Sort key comparing digit runs numerically ("10.jpg" after "2.jpg")."""
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', name))]


def _plan_frames(frames, out_frames):
    """Return exactly ``out_frames`` names drawn from ``frames``, in temporal order.

    ``frames`` must be non-empty and already in temporal order.
    """
    n = len(frames)
    if n >= out_frames:
        # Evenly spaced subsample. Indices strictly increase because n >= out_frames.
        # Previously out_frames // n was 0 here and the modulo below divided by zero.
        return [frames[(i * n) // out_frames] for i in range(out_frames)]

    # Repeat every frame floor(out_frames / n) times (duplicates stay adjacent),
    # then pad with the remainder: the middle frame when the remainder is odd,
    # then symmetric pairs stepping in from both ends.
    repeated = sorted(frames * (out_frames // n), key=_natural_sort_key)
    length = len(repeated)
    add_frames = out_frames % length
    step = length // (add_frames + 1)

    extra = []
    if add_frames % 2 != 0:
        extra.append(repeated[length // 2])
        add_frames -= 1
    pos = step
    while add_frames > 0:
        extra.append(repeated[pos - 1])
        extra.append(repeated[length - pos])
        pos += step
        add_frames -= 2

    return sorted(repeated + extra, key=_natural_sort_key)


def extrapolate(input_dir, output_dir, out_frames: int = 16):
    """Copy each video under ``input_dir`` to ``output_dir`` with exactly ``out_frames`` frames.

    Every sub-directory of ``input_dir`` is one video. Its frames are ordered
    naturally by filename. Short videos are padded by repeating frames, and
    longer ones are evenly subsampled. Output frames are written as
    ``0.jpg``, ``1.jpg``, ... under ``output_dir/<video>``. Non-directory
    entries and videos without frames are skipped.

    Args:
        input_dir: directory containing one frame directory per video.
        output_dir: destination root (created as needed).
        out_frames: number of frames every output video gets.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    error_count = 0
    videos = sorted(os.listdir(input_dir))

    for video in tqdm(videos):
        src_dir = input_dir / video
        if not src_dir.is_dir():
            # e.g. classes.csv / labels.csv next to the video folders
            continue

        frames = sorted(os.listdir(src_dir), key=_natural_sort_key)
        if not frames:
            print(f'----> Skipping {video}: video has no frames')
            continue

        new_vid_frames = _plan_frames(frames, out_frames)

        out_path = output_dir / video
        out_path.mkdir(parents=True, exist_ok=True)
        for idx, frame in enumerate(new_vid_frames):
            shutil.copy(src_dir / frame, out_path / f'{idx}.jpg')

        if len(os.listdir(out_path)) != out_frames:
            print(len(new_vid_frames))
            error_count += 1

    if error_count > 0:
        print(f'----> {error_count} videos were not copied')
    else:
        print('----> All videos were copied')
class VideoLogitDataset_noise(Dataset):
    """Dataset of pickled frame arrays listed in a CSV, with one-hot labels.

    Args:
        video_dir_path: prefix prepended verbatim to every ``FileNames`` entry
            (so it should end with a path separator); ``.pkl`` is appended.
        csv_file: CSV with ``FileNames`` and ``Labels`` columns.
        num_classes: length of the one-hot label (default 600, the value that
            used to be hard-coded).
    """

    def __init__(self, video_dir_path, csv_file, num_classes=600):
        self.video_dir_path = video_dir_path
        self.num_classes = num_classes
        self.csv = pd.read_csv(csv_file)

        self.videos = [self.video_dir_path + name + '.pkl' for name in self.csv['FileNames']]
        self.labels = self.csv['Labels']

        self.num_instances = len(self.csv)

    def __getitem__(self, idx):
        # NOTE(review): pickle can execute arbitrary code -- only load trusted files.
        with open(self.videos[idx], 'rb') as fh:
            frames = torch.tensor(pickle.load(fh)).to(torch.float32)
        label = np.zeros((self.num_classes))
        label[self.labels[idx]] = 1
        return frames, torch.tensor(label)

    def __len__(self):
        return self.num_instances

class VideoLogitDataset_cgan_noise(Dataset):
    """Single-frame samples stored as ``<dir>/<idx>/<idx>.png``, paired with logits.

    Args:
        video_dir_path: directory whose entries are counted as samples; sample
            ``idx`` is read from ``<video_dir_path>/<idx>/<idx>.png``.
        logits_file: pickle containing a torch tensor (``.numpy()`` is called on
            it); row ``idx`` is returned with sample ``idx``.
    """

    def __init__(self, video_dir_path, logits_file):
        self.video_dir_path = video_dir_path
        self.num_instances = len(os.listdir(self.video_dir_path))
        # NOTE(review): pickle can execute arbitrary code -- only load trusted files.
        with open(logits_file, 'rb') as fh:
            logits = pickle.load(fh)
        self.logits = torch.tensor(logits.numpy())

    def __getitem__(self, idx):
        path = os.path.join(self.video_dir_path, str(idx), f'{idx}.png')
        with Image.open(path) as img:
            frame = np.array(img)
        # NOTE(review): this *reshapes* the (H, W, C) array to (C, 1, H, W); it does
        # not transpose, so pixel values are reinterpreted rather than moved to
        # channel-first. Kept as-is in case the PNGs were written with the inverse
        # reshape -- confirm against the generator; a true channel-first layout
        # would be ``frame.transpose(2, 0, 1)[:, None]``.
        frame = torch.tensor(frame.reshape((frame.shape[2], 1, frame.shape[0], frame.shape[1])))
        return frame.to(torch.float32), self.logits[idx]

    def __len__(self):
        return self.num_instances

In [None]:
class AverageMeter(object):
    """Keep the most recent value and the sample-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class Logger(object):
    """Tab-separated training log with a header row.

    Args:
        path: log file path.
        header: column names; ``log`` expects a dict containing every one of them.
        resume_path: '' for a fresh run. Otherwise, when ``path`` already exists,
            only its first ``begin_epoch + 1`` rows (header + ``begin_epoch`` data
            rows) are kept and new rows are appended after them.
        begin_epoch: epoch the resumed run starts from.
    """

    def __init__(self, path, header, resume_path, begin_epoch):
        self.header = header
        if (not os.path.exists(path)) or (resume_path == ''):
            self.log_file = open(path, 'w+', newline='')
            self.logger = csv.writer(self.log_file, delimiter='\t')
            self.logger.writerow(header)
        else:
            print("begin = ", begin_epoch)
            with open(path, 'r', newline='') as existing:
                kept_rows = list(csv.reader(existing, delimiter='\t'))[:begin_epoch + 1]
            self.log_file = open(path, 'w', newline='')
            self.logger = csv.writer(self.log_file, delimiter='\t')
            self.logger.writerows(kept_rows)
            self.log_file.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        log_file = getattr(self, 'log_file', None)
        if log_file is not None and not log_file.closed:
            log_file.close()

    def __del__(self):
        # The old finaliser was named ``__del`` (name-mangled and never called),
        # so the file was never closed.
        self.close()

    def log(self, values):
        """Append one row: ``values[col]`` for every header column (all required)."""
        write_values = []
        for col in self.header:
            assert col in values
            write_values.append(values[col])

        self.logger.writerow(write_values)
        self.log_file.flush()

class Logger_MARS(object):
    """Tab-separated MARS training log with a header row.

    Args:
        path: log file path.
        header: column names; ``log`` expects a dict containing every one of them.
        resume_path: '' for a fresh run. Otherwise, when ``path`` already exists,
            only its first ``begin_epoch + 1`` rows (header + ``begin_epoch`` data
            rows) are kept and new rows are appended after them. A missing log
            file now starts a fresh log instead of raising FileNotFoundError,
            matching ``Logger``.
        begin_epoch: epoch the resumed run starts from.
    """

    def __init__(self, path, header, resume_path, begin_epoch):
        self.header = header
        if resume_path == '' or not os.path.exists(path):
            self.log_file = open(path, 'w+', newline='')
            self.logger = csv.writer(self.log_file, delimiter='\t')
            self.logger.writerow(header)
        else:
            print("begin = ", begin_epoch)
            with open(path, 'r', newline='') as existing:
                kept_rows = list(csv.reader(existing, delimiter='\t'))[:begin_epoch + 1]
            self.log_file = open(path, 'w', newline='')
            self.logger = csv.writer(self.log_file, delimiter='\t')
            self.logger.writerows(kept_rows)
            self.log_file.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        log_file = getattr(self, 'log_file', None)
        if log_file is not None and not log_file.closed:
            log_file.close()

    def __del__(self):
        # The old finaliser was named ``__del`` (name-mangled and never called),
        # so the file was never closed.
        self.close()

    def log(self, values):
        """Append one row: ``values[col]`` for every header column (all required)."""
        write_values = []
        for col in self.header:
            assert col in values
            write_values.append(values[col])

        self.logger.writerow(write_values)
        self.log_file.flush()


def load_value_file(file_path):
    """Return the single number stored in ``file_path`` as a float."""
    with open(file_path, 'r') as handle:
        text = handle.read()
    return float(text.rstrip('\n\r'))

def calculate_accuracy(outputs, targets):
    """Top-1 accuracy (a Python float) of 2-D ``outputs`` against 2-D ``targets``.

    The true class of each row is the arg-max of ``targets``, so one-hot and
    soft (probability) targets are both accepted.
    """
    n_samples = targets.size(0)
    predicted = outputs.topk(1, 1, True)[1].t()
    expected = targets.topk(1, 1, True)[1].t()
    hits = predicted.eq(expected.view(1, -1))
    return hits.float().sum().item() / n_samples

def calculate_accuracy5(output, target, topk=5):
    """Top-k accuracy of ``output`` against one-hot / soft ``target`` rows.

    Args:
        output: 2-D score tensor, one row per sample.
        target: 2-D tensor; the true class of a row is its arg-max.
        topk: k. Clamped to the number of classes so small label spaces do not
            make ``Tensor.topk`` raise.

    Returns:
        0-dim tensor with the fraction of rows whose true class is among the
        ``k`` highest-scoring predictions.
    """
    k = min(topk, output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(k, 1, True, True)
    pred = pred.t()
    _, targets = target.topk(1, 1, True)
    correct = pred.eq(targets.t().view(1, -1).expand_as(pred))
    # ``correct`` is built from a transposed tensor and is generally
    # non-contiguous, so ``.view(-1)`` can raise; ``.reshape(-1)`` copies only if needed.
    correct_k = correct[:k].reshape(-1).float().sum(0)
    return correct_k.mul_(1.0 / batch_size)
    
def calculate_accuracy_video(output_buffer, i):
    """Accuracy over the first ``i + 1`` rows of a 2-D ``output_buffer`` array.

    Each row holds class scores followed by the ground-truth label in its last
    column; a row is correct when the arg-max score index equals that label.
    """
    rows = output_buffer[:i + 1]
    labels = rows[:, -1]
    predictions = np.argmax(rows[:, :-1], axis=1)
    n_correct = 1 * (np.equal(labels, predictions)).sum()
    return n_correct / len(labels)


In [None]:
# W&B settings for this run. The API key is read from the WANDB_API_KEY
# environment variable (e.g. set via Colab secrets or `%env`) instead of being
# committed to the notebook. The key that used to be hard-coded here is exposed
# and should be revoked.
args = {
    'wandb_api_key': os.environ.get('WANDB_API_KEY', ''),
    'wandb_name': 'grey_box_ucf_pretrained_k600_training',
    'project': 'model_extraction',
}

In [None]:
opt = parse_opts()
print(opt)

# Derived options: architecture tag and the torch RNG seed.
opt.arch = f'{opt.model}-{opt.model_depth}'
torch.manual_seed(opt.manual_seed)

# Channel count follows the modality (RGB -> 3, Flow -> 2); any other
# modality keeps the parsed --input_channels value.
opt.input_channels = {'RGB': 3, 'Flow': 2}.get(opt.modality, opt.input_channels)

print("Loading model... ", opt.model, opt.model_depth)
model, parameters = generate_model(opt)

In [None]:
# Log in, start the W&B run and attach wandb.watch to the model.
init_wandb(model, args=args)

In [None]:
# Copy the zipped training frames from Drive and extract them quietly into the
# current working directory. Per the file name, this is a 5% Kinetics-400 train
# subset with 16 uniformly sampled frames per video.
!cp /content/drive/MyDrive/kinetics_final/k400/k400_train_5_percent_16_frames_uniform.zip /content/
!unzip -q /content/k400_train_5_percent_16_frames_uniform.zip

In [None]:
# Copy and extract the Kinetics-400 validation frames the same way.
# NOTE(review): the dataset cell reads /content/k400_16_frames_uniform (plus its
# classes.csv / labels.csv), presumably this archive's top-level folder -- confirm.
!cp /content/drive/MyDrive/kinetics_final/k400/k400_val_16_frames_uniform.zip /content/
!unzip -q /content/k400_val_16_frames_uniform.zip

In [None]:
# VideoLogitDataset takes (video_dir_path, logits_file, transform=None).
# The previous call passed the annotation list (train_k400_5_percent_list.txt)
# as `logits_file` and the logits pickle as `transform`. That cannot work: the
# .txt is not a pickle, and a str is not callable. The annotation list is not
# used by VideoLogitDataset, so only the logits pickle is passed, by keyword.
# NOTE(review): confirm '/content/Videos' is where the training frames were extracted.
finetune_dataset_aug = VideoLogitDataset(
    '/content/Videos',
    logits_file='/content/drive/MyDrive/logits/swin_transformer_k400_train5precent_2.pkl',
)
val_dataset = ValDataset(
    '/content/k400_16_frames_uniform',
    '/content/k400_16_frames_uniform/classes.csv',
    '/content/k400_16_frames_uniform/labels.csv',
    num_classes=400,
)

In [None]:
# `finetune_dataset` was never defined (NameError); the training set built
# in the previous cell is `finetune_dataset_aug`.
train_data = finetune_dataset_aug
val_data = val_dataset

In [None]:
print("Preparing datatloaders ...")
train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.n_workers, pin_memory=True)
# Validation metrics are averages over the whole set, so shuffling buys nothing.
# It also consumed the torch RNG, which perturbed the order of later training shuffles.
val_dataloader = DataLoader(val_data, batch_size=opt.batch_size, shuffle=False,
                            num_workers=opt.n_workers, pin_memory=True)
print("Length of train datatloader = ", len(train_dataloader))
print("Length of validation datatloader = ", len(val_dataloader))

In [None]:
# Starting from a pretrained checkpoint overrides the parsed optimizer
# defaults with fine-tuning values.
if opt.pretrain_path:
    opt.weight_decay = 1e-5
    opt.learning_rate = 0.001

# torch.optim.SGD only accepts Nesterov momentum with zero dampening.
dampening = 0 if opt.nesterov else opt.dampening

print("lr = {} \t momentum = {} \t dampening = {} \t weight_decay = {}, \t nesterov = {}".format(
    opt.learning_rate, opt.momentum, dampening, opt.weight_decay, opt.nesterov))
print("LR patience = ", opt.lr_patience)

optimizer = optim.SGD(
    parameters,
    lr=opt.learning_rate,
    momentum=opt.momentum,
    dampening=dampening,
    weight_decay=opt.weight_decay,
    nesterov=opt.nesterov,
)

In [None]:
# Cross-entropy with class-probability targets (softmaxed logits during
# training, one-hot vectors during validation); requires PyTorch >= 1.10.
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()

In [None]:
# Lower the learning rate once the metric passed to scheduler.step() has
# stopped decreasing for `lr_patience` epochs.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=opt.lr_patience)

In [None]:
# Checkpoints go here; the directory is created up front because torch.save
# does not create missing parent directories.
CHECKPOINT_DIR = '/checkpoints'

_TRAIN_LOG_FMT = ('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Aug Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Aug Acc {acc.val:.3f} ({acc.avg:.3f})')
_VAL_LOG_FMT = ('Val_Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Acc {acc.val:.3f} ({acc.avg:.3f})')


def _unwrapped_state_dict(net):
    """State dict of ``net``, unwrapping ``nn.DataParallel``'s ``.module`` when present."""
    return getattr(net, 'module', net).state_dict()


def _run_epoch(net, loader, epoch, criterion, optimizer=None):
    """One pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

    Training targets (the logits returned by VideoLogitDataset) are turned into
    probabilities with a softmax. Validation targets (one-hot vectors from
    ValDataset) are used as-is.

    Returns:
        (losses, accuracies): AverageMeter instances for this epoch.
    """
    training = optimizer is not None
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, accuracies = AverageMeter(), AverageMeter()
    log_fmt = _TRAIN_LOG_FMT if training else _VAL_LOG_FMT
    net.train(training)

    end_time = time.time()
    with torch.set_grad_enabled(training):
        for i, (inputs, targets) in enumerate(loader):
            data_time.update(time.time() - end_time)
            # The datasets yield (B, C, H, W, T) clips; reorder to (B, C, T, H, W).
            inputs = inputs.permute(0, 1, 4, 2, 3).cuda(non_blocking=True)
            if training:
                targets = F.softmax(targets, dim=1)
            targets = targets.to(torch.float32).cuda(non_blocking=True)

            outputs = net(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)
            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print(log_fmt.format(epoch, i + 1, len(loader),
                                 batch_time=batch_time, data_time=data_time,
                                 loss=losses, acc=accuracies))
    return losses, accuracies


os.makedirs(CHECKPOINT_DIR, exist_ok=True)
best_val_acc = 0
for epoch in range(opt.begin_epoch, opt.n_epochs + 1):
    train_losses, train_accs = _run_epoch(model, train_dataloader, epoch, criterion, optimizer)
    val_losses, val_accs = _run_epoch(model, val_dataloader, epoch, criterion)

    # One W&B entry per epoch (previously re-logged after every batch at the same step).
    wandb.log({
        'Training loss': train_losses.avg,
        'Training Accuracy': train_accs.avg,
        'Validation loss': val_losses.avg,
        'Validation Accuracy': val_accs.avg,
    }, step=epoch)

    # The plateau scheduler ('min' mode) was created but never stepped.
    scheduler.step(val_losses.avg)

    if val_accs.avg > best_val_acc:
        best_val_acc = val_accs.avg
        torch.save(_unwrapped_state_dict(model), os.path.join(CHECKPOINT_DIR, 'best_model.pth'))

    # Periodic checkpoint every `--checkpoint` epochs (default 2, the old hard-coded value).
    if opt.checkpoint > 0 and epoch % opt.checkpoint == 0:
        torch.save(_unwrapped_state_dict(model), os.path.join(CHECKPOINT_DIR, f'{epoch}.pth'))