In [None]:
# Clone only when the checkout is missing so "Restart & Run All" does not fail
# with "destination path already exists".
![ -d TANS_trial ] || git clone https://github.com/Swastik166/TANS_trial.git
    

In [None]:
import os
import sys

# Make the cloned repository importable. Resolve to an absolute path (a later
# chdir would break a relative entry) and guard so re-running the cell does not
# stack duplicate entries at the front of sys.path.
REPO_DIR = os.path.abspath('./TANS_trial')
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# %pip (not !pip) installs into the environment of the running kernel.
# NOTE(review): the PyPI package "config" may be unrelated to any local config
# module the repo ships — confirm it is really needed, then pin its version.
%pip install config

In [None]:
%%bash
# Launch TANS training from inside the repository checkout.
# The path arguments below are placeholders and must be replaced before running.
cd TANS_trial
python3 main.py \
    --gpu 0 \
    --mode train \
    --batch-size 140 \
    --n-epochs 10000 \
    --base-path path/for/storing/outcomes/ \
    --data-path path/to/processed/dataset/is/stored/ \
    --model-zoo path/to/model_zoo.pt \
    --seed 777

In [None]:
import random
import time

import numpy as np
import torch
import torchvision.models as models
from tqdm import tqdm

# TODO: import the helper utilities from the repo's misc module and use the functions available there for better error handling.

class ModelZoo:
    """Build a model zoo of OFA MobileNetV3 subnets fine-tuned on several datasets.

    For every dataset, ``n_models_per_dataset`` random subnets are sampled from
    the once-for-all supernet and fine-tuned. For each subnet the zoo records
    its flattened topology, final validation accuracy, a functional embedding
    (pre-classifier response to a fixed noise input) and its trainable
    parameter count.
    """

    def __init__(self, datasets, seed, noise_path='/path/to/noise.pt',
                 n_eps_finetuning=50, patience=5, lr=1e-2,
                 n_models_per_dataset=10, zoo_path='model_zoo.pt'):
        """
        Args:
            datasets: iterable of dataset identifiers passed to ``get_loader``.
            seed: seed applied to torch (CPU and CUDA), numpy and ``random``.
            noise_path: file holding the fixed noise input used by ``f_emb``.
            n_eps_finetuning: maximum fine-tuning epochs per subnet
                (placeholder default — TODO tune per experiment).
            patience: epochs without validation-loss improvement before early
                stopping (placeholder default — TODO tune).
            lr: initial SGD learning rate.
            n_models_per_dataset: number of subnets sampled per dataset.
            zoo_path: output file written by ``save_zoo``.
        """
        self.seed = seed
        self.datasets = datasets
        self.models = {}
        self.train_instances = {}
        self.n_eps_finetuning = n_eps_finetuning
        self.patience = patience
        self.lr = lr
        self.n_models_per_dataset = n_models_per_dataset
        self.zoo_path = zoo_path
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Load the noise directly onto the training device so f_emb never mixes devices.
        self.noise = torch.load(noise_path, map_location=self.device)
        torch.cuda.manual_seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

    def init_loaders(self, dataset):
        """Build train/test/validation loaders for ``dataset`` and cache its class count.

        NOTE(review): ``get_loader`` is neither defined nor imported in this
        notebook; presumably it comes from the TANS_trial repository — import it
        before calling this method.
        """
        self.tr_dataset, self.tr_loader = get_loader(dataset, mode='train')
        self.te_dataset, self.te_loader = get_loader(dataset, mode='test')
        self.val_dataset, self.val_loader = get_loader(dataset, mode='validation')
        self.nclass = self.tr_dataset.get_nclss()
        # Remember which dataset the loaders belong to (used by get_query).
        self.current_dataset = dataset

    def _init_optimization(self, net):
        """Attach a fresh loss, SGD optimizer and cosine LR schedule for ``net``."""
        self.lss = torch.nn.CrossEntropyLoss()
        self.optim = torch.optim.SGD(net.parameters(), lr=self.lr, momentum=0.9, weight_decay=4e-5)
        # train() steps the scheduler once per epoch, so T_max is measured in epochs.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, T_max=self.n_eps_finetuning)

    def create_zoo(self):
        """Sample, fine-tune and record subnets for every dataset, then save the zoo.

        Returns:
            dict mapping 'dataset', 'topol', 'acc', 'f_emb', 'n_params' to
            parallel lists with one entry per trained subnet. The same dict is
            written to disk by ``save_zoo``.
        """
        zoo = {key: [] for key in ('dataset', 'topol', 'acc', 'f_emb', 'n_params')}

        for dataset in self.datasets:
            self.init_loaders(dataset)

            for _ in range(self.n_models_per_dataset):
                topol, net = self.get_net(self.nclass)
                self.topol = topol  # read by train() for logging
                self._init_optimization(net)

                acc = self.train(net, dataset, self.tr_loader, self.val_loader, self.nclass)

                zoo['dataset'].append(dataset)
                zoo['topol'].append(topol)
                zoo['acc'].append(acc)
                zoo['f_emb'].append(self.f_emb(net))
                zoo['n_params'].append(self.n_param(net))

                # Drop every reference to the trained subnet before sampling the next one.
                del net
                self.model = self.optim = self.scheduler = self.lss = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        self.save_zoo(zoo)
        return zoo

    def get_net(self, nclss):
        """Sample a random OFA MobileNetV3 subnet with a fresh ``nclss``-way classifier.

        Returns:
            (topol, net): ``topol`` is the flat list ks + e + d of the sampled
            configuration; ``net`` is the subnet, moved to ``self.device``.
        """
        super_net_name = "ofa_supernet_mbv3_w10"

        # Reloaded on every call so each subnet starts from the pretrained weights.
        super_net = torch.hub.load('mit-han-lab/once-for-all', super_net_name, pretrained=True).eval()
        sampled_config = super_net.sample_active_subnet()

        # Read the settings by name rather than slicing a flattened list at
        # hard-coded offsets (20/40), which depended on dict order and list lengths.
        # Assumes the keys match the set_active_subnet keywords (ks/e/d).
        ks, e, d = sampled_config['ks'], sampled_config['e'], sampled_config['d']
        topol = [*ks, *e, *d]

        super_net.set_active_subnet(ks=ks, e=e, d=d)
        active_subnet = super_net.get_active_subnet(preserve_weight=True)

        # Size the new head from the existing one. 1536 is only the fallback the
        # original code hard-coded; NOTE(review): the w10 supernet's feature width
        # may differ from 1536, and reading in_features avoids a shape mismatch.
        in_features = getattr(active_subnet.classifier, 'in_features', 1536)
        active_subnet.classifier = torch.nn.Linear(in_features, nclss)

        return topol, active_subnet.to(self.device)

    def f_emb(self, net):
        """Return the pre-classifier response of ``net`` to ``self.noise`` (on CPU).

        If ``net`` has a ``classifier`` attribute, it is temporarily replaced by an
        identity, so the network's own forward (pooling, flattening, ...) runs
        unchanged. Otherwise the last child module is dropped and the remaining
        children are chained. The net's head and train/eval mode are restored
        afterwards. Assumes ``self.noise`` is a batch ``net`` accepts — TODO confirm.
        """
        head = getattr(net, 'classifier', None)
        was_training = net.training
        net.eval()
        try:
            if head is not None:
                net.classifier = torch.nn.Identity()
                extractor = net
            else:
                extractor = torch.nn.Sequential(*list(net.children())[:-1])
            with torch.no_grad():
                emb = extractor(self.noise.to(self.device))
        finally:
            if head is not None:
                net.classifier = head
            net.train(was_training)
        # Store on CPU so the zoo does not pin GPU memory and loads anywhere.
        return emb.detach().cpu()

    def _run_train_epoch(self, loader):
        """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
        self.model.train()
        total_loss, n_seen = 0.0, 0
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            self.optim.zero_grad()
            loss = self.lss(self.model(x), y)
            loss.backward()
            self.optim.step()
            total_loss += loss.item() * x.size(0)
            n_seen += x.size(0)
        if n_seen == 0:
            raise ValueError('training loader yielded no samples')
        return total_loss / n_seen

    def _run_val_epoch(self, loader):
        """Evaluate ``self.model`` on ``loader``; return (mean loss, accuracy in percent)."""
        self.model.eval()
        total_loss, n_seen, n_correct = 0.0, 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                outputs = self.model(x)
                total_loss += self.lss(outputs, y).item() * x.size(0)
                n_correct += (outputs.argmax(dim=1) == y).sum().item()
                n_seen += y.size(0)
        if n_seen == 0:
            raise ValueError('validation loader yielded no samples')
        return total_loss / n_seen, 100.0 * n_correct / n_seen

    def train(self, model, dataset, train_loader, val_loader, nclss):
        """Fine-tune ``model`` in place with early stopping on validation loss.

        Uses ``self.lss``, ``self.optim`` and ``self.scheduler`` (set up by
        create_zoo), plus ``self.n_eps_finetuning`` and ``self.patience``.

        Args:
            model: network to train.
            dataset: dataset name, used only for logging.
            train_loader: iterable of (x, y) training batches.
            val_loader: iterable of (x, y) validation batches.
            nclss: number of classes (unused; kept for interface compatibility).

        Returns:
            Validation accuracy (percent) of the last epoch that ran.

        Raises:
            ValueError: if no epochs are configured or a loader yields no samples.
        """
        if self.n_eps_finetuning < 1:
            raise ValueError('n_eps_finetuning must be at least 1')
        self.model = model
        topol = getattr(self, 'topol', None)
        print(f'Starting Training for:{dataset} with model topology:{topol}')

        best_val_loss = float('inf')
        counter = 0
        val_acc = []

        for ep in range(self.n_eps_finetuning):
            self.curr_ep = ep

            st = time.time()
            ep_loss_tr = self._run_train_epoch(train_loader)
            tr_time = time.time() - st

            st = time.time()
            ep_loss_val, acc = self._run_val_epoch(val_loader)
            val_time = time.time() - st

            # The cosine schedule is defined in epochs, so step once per epoch.
            self.scheduler.step()
            val_acc.append(acc)

            print(f' ==> [dataset:{dataset}] ep:{ep + 1}, tr_lss:{ep_loss_tr:.3f},'
                  f' val_lss:{ep_loss_val:.3f}, val_acc:{acc:.3f},'
                  f' tr_time:{tr_time:.3f}s, val_time:{val_time:.3f}s')

            if ep_loss_val < best_val_loss:
                best_val_loss = ep_loss_val
                counter = 0
            else:
                counter += 1

            if counter >= self.patience:
                print(f"Early stopping on, {ep}th, epoch")
                break

        return val_acc[-1]

    def n_param(self, model):
        """Return the number of trainable parameters in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def get_query(self, dataset):
        """Embed one training batch and one test batch with an ImageNet ResNet-18.

        The loaders are rebuilt if they do not already belong to ``dataset``.
        Only the first batch of each loader is used.

        Returns:
            (query_train, query_test): ResNet-18 outputs with the final fc layer
            removed, computed without gradients.
        """
        if getattr(self, 'current_dataset', None) != dataset:
            self.init_loaders(dataset)
        backbone = models.resnet18(pretrained=True)
        extractor = torch.nn.Sequential(*list(backbone.children())[:-1]).eval()
        with torch.no_grad():
            train_images, _ = next(iter(self.tr_loader))
            query_train = extractor(train_images)
            test_images, _ = next(iter(self.te_loader))
            query_test = extractor(test_images)
        return query_train, query_test

    def save_zoo(self, dict):
        """Serialize the zoo mapping ``dict`` to ``self.zoo_path`` (default 'model_zoo.pt')."""
        torch.save(dict, getattr(self, 'zoo_path', 'model_zoo.pt'))

    def save_trainpt(self):
        """Write an m_train.pt skeleton holding empty placeholder fields per dataset."""
        fields = ('task', 'clss', 'nclss', 'x_query_train', 'x_query_test')
        trainpt_dict = {dataset: {field: {} for field in fields} for dataset in self.datasets}
        torch.save(trainpt_dict, 'm_train.pt')


In [1]:
# Toy zoo filled with placeholder strings: two entries per dataset, laid out as
# parallel lists under the same keys that ModelZoo.create_zoo appends to.
model_zoo = {
    'dataset': [f'dataset{i}' for i in (1, 1, 2, 2, 3, 3)],
    'topol': [f'model{i}' for i in range(1, 7)],
    'acc': [f'acc{i}' for i in range(1, 7)],
    'f_emb': [f'f_emb{i}' for i in range(1, 7)],
    'n_params': [f'np{i}' for i in range(1, 7)],
}

In [2]:
# Display the toy zoo; as the last expression of the cell it becomes the cell's output.
model_zoo

{'dataset': ['dataset1',
  'dataset1',
  'dataset2',
  'dataset2',
  'dataset3',
  'dataset3'],
 'topol': ['model1', 'model2', 'model3', 'model4', 'model5', 'model6'],
 'acc': ['acc1', 'acc2', 'acc3', 'acc4', 'acc5', 'acc6'],
 'f_emb': ['f_emb1', 'f_emb2', 'f_emb3', 'f_emb4', 'f_emb5', 'f_emb6'],
 'n_params': ['np1', 'np2', 'np3', 'np4', 'np5', 'np6']}

In [3]:
import torch 
# Persist the toy zoo so the next cell can check the load round-trip.
torch.save(obj=model_zoo, f='zoo.pt')

In [4]:
# Reload the saved zoo and show its contents followed by its container type.
zoo = torch.load('zoo.pt')
print(zoo, type(zoo), sep='\n')

{'dataset': ['dataset1', 'dataset1', 'dataset2', 'dataset2', 'dataset3', 'dataset3'], 'topol': ['model1', 'model2', 'model3', 'model4', 'model5', 'model6'], 'acc': ['acc1', 'acc2', 'acc3', 'acc4', 'acc5', 'acc6'], 'f_emb': ['f_emb1', 'f_emb2', 'f_emb3', 'f_emb4', 'f_emb5', 'f_emb6'], 'n_params': ['np1', 'np2', 'np3', 'np4', 'np5', 'np6']}
<class 'dict'>


In [6]:
# Simulate recording one more subnet: push the same sentinel onto every
# parallel list so they stay equal length.
# NOTE: this mutates `zoo` in place — re-running the cell appends again.
for key in ('dataset', 'topol', 'acc', 'f_emb', 'n_params'):
    zoo[key].append('test')
print(zoo)

{'dataset': ['dataset1', 'dataset1', 'dataset2', 'dataset2', 'dataset3', 'dataset3', 'test'], 'topol': ['model1', 'model2', 'model3', 'model4', 'model5', 'model6', 'test'], 'acc': ['acc1', 'acc2', 'acc3', 'acc4', 'acc5', 'acc6', 'test'], 'f_emb': ['f_emb1', 'f_emb2', 'f_emb3', 'f_emb4', 'f_emb5', 'f_emb6', 'test'], 'n_params': ['np1', 'np2', 'np3', 'np4', 'np5', 'np6', 'test']}
