In [2]:
!python --version

Python 3.7.12


In [7]:
from tqdm import tqdm
import ast
import os
import argparse
import torch
from cnnnet import CnnNet
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [10]:
# model_load용
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--structure_txt', type=str)
    parser.add_argument('--pretrained', type=str, default=None)
    args = parser.parse_args()
    return args

def get_model(filename,
                pretrained=True,
                network_id=0,
                classification=True #False for detetion
                 ):
    # load best structures
    with open(filename, 'r') as fin:
        content = fin.read()
        output_structures = ast.literal_eval(content)

    network_arch = output_structures['space_arch']
    best_structures = output_structures['best_structures']

    # If task type is classification, param num_classes is required
    out_indices = (1, 2, 3, 4) if not classification else (4, )
    model = CnnNet(
            structure_info=best_structures[network_id],
            out_indices=out_indices,
            num_classes=1000,
            classification=classification)
    model.init_weights(pretrained)

    return model, network_arch

def load_model(model,
               load_parameters_from,
               strict_load=False,
               map_location=torch.device('cpu'),
               **kwargs):
    if not os.path.isfile(load_parameters_from):
        raise ValueError('bad checkpoint to load %s' % (load_parameters_from))
    else:
        model.logger.debug('Zennas: loading params from '
                           + load_parameters_from)
    checkpoint = torch.load(load_parameters_from, map_location=map_location)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # print("\n#################################")
    # for name, paramets in model.named_parameters():
    # print(name, paramets.size(), paramets.flatten().cpu().detach().numpy()[0:5])
    #     if "conv_offset.weight" in name:
    #         state_dict[name] = state_dict[name.replace(".conv_offset", "")]
    model.load_state_dict(state_dict, strict=strict_load)

    # print("\n#################################")
    # for name, paramets in model.named_parameters():
    #     print(name, paramets.size(), paramets.flatten().cpu().detach().numpy()[0:5])

    return model

In [11]:
def generate_gram_matrix(self, x):
    # x를 2D 텐서로 reshape
    x = x.view(x.size(0), -1)
    
    # Gram 행렬 계산
    gram = torch.matmul(x, x.t())
    
    # Gram 행렬의 대각선을 0으로 설정
    n = gram.size(0)
    gram.fill_diagonal_(0)
    
    # 데이터 타입 변환
    gram = gram.to(self.hsic_accumulator.dtype)
    
    # 평균 계산
    means = gram.sum(dim=0) / (n - 2)
    means -= means.sum() / (2 * (n - 1))
    
    # 평균을 뺀 Gram 행렬 계산
    gram -= means.unsqueeze(1)
    gram -= means.unsqueeze(0)
    
    # 대각선을 다시 0으로 설정
    gram.fill_diagonal_(0)
    
    # 1D 텐서로 변환
    gram = gram.view(-1)
    
    return gram

In [12]:
def update_state(self, activations):
    layer_grams = [self._generate_gram_matrix(x) for x in activations]
    layer_grams = torch.stack(layer_grams, 0)
    print(self.hsic_accumulator.shape)
    self.hsic_accumulator += torch.matmul(layer_grams, layer_grams.t())

In [13]:
if __name__ == '__main__':
    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # make input
    #args = parse_args()
    args = "--structure_txt=best_structure.txt"
    # model_get
    model, network_arch = get_model(args.structure_txt, 
                                            args.pretrained)

    # state load
    model.load_state_dict(torch.load('./tinynas/deploy/cnnnet/tinynascnnex1_statedict_1.pth'))

    # GPU 사용 설정
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    

usage: ipykernel_launcher.py [-h] [--structure_txt STRUCTURE_TXT]
                             [--pretrained PRETRAINED]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/deeparc/.local/share/jupyter/runtime/kernel-3adaa313-e47a-4bff-99c0-46e358246f41.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
