In [51]:
import torchsummary
import loader
import research_models
from importlib import reload
import wandb
# Authenticate this kernel with Weights & Biases up front, before any cell
# tries to define metrics or log to a run.
wandb.login()

In [124]:
# Build the dataset wrapper and prepare its VALID split for evaluation.
test = loader.TinyImageNet200()
test.prepare(split=loader.VALID)

# The dataloader yields (inputs, labels) pairs -- the evaluation loop further
# down unpacks each batch as `x, y_true`. Keep only the inputs in `x`, because
# later cells feed `x` straight into a model.
x, _labels = next(iter(test.dataloader(batch_size=1000)))

# Running counter for the custom wandb step metric; later cells increment it.
step = 1

In [112]:
# Smoke test: build a ResidualNet for every combination of conv kernel size,
# pool kernel size and pool stride, push the sample batch `x` through it and
# print the shape of the per-sample argmax.
for conv_kernel_size, pool_kernel_size, pool_stride in (
    (conv, kernel, stride)
    for conv in [2, 3]
    for kernel in [2, 3]
    for stride in [2, 3]
):
    config = dict(
        conv_kernel_size=conv_kernel_size,
        pool_kernel_size=pool_kernel_size,
        pool_stride=pool_stride,
    )
    # The linear size is looked up from the registry for this geometry.
    config['linear_output'] = get_linear_output_dict(**config)
    print(config)
    y = research_models.ResidualNet(**config)(x).argmax(1)
    print(y.shape)

{'conv_kernel_size': 2, 'pool_kernel_size': 2, 'pool_stride': 2, 'linear_output': 2048}
torch.Size([1000])
{'conv_kernel_size': 2, 'pool_kernel_size': 3, 'pool_stride': 2, 'linear_output': 1152}
torch.Size([1000])
{'conv_kernel_size': 3, 'pool_kernel_size': 2, 'pool_stride': 2, 'linear_output': 2048}
torch.Size([1000])
{'conv_kernel_size': 3, 'pool_kernel_size': 3, 'pool_stride': 2, 'linear_output': 1152}
torch.Size([1000])


In [144]:
# NOTE: wandb.define_metric / wandb.log need an active run. Uncomment the line
# below (or call wandb.init elsewhere) if no run has been started in this kernel.
# wandb.init(project='test', entity='adnanhd', name='bar')
config = {'conv_kernel_size': 2, 'pool_kernel_size': 2, 'pool_stride': 2, 'linear_output': 2048}
model = research_models.ResidualNet(**config)

# Metric definitions and class labels do not change between batches, so set
# them up once instead of on every iteration.
wandb.define_metric('test_step')
wandb.define_metric('test/accuracy', step_metric='test_step')
class_names = [str(i) for i in range(200)]

for x, y_true in test.dataloader(batch_size=1000):
    step += 1
    labels = y_true.numpy()
    # detach() drops the autograd graph so the logits can be converted to numpy.
    y_pred = model(x).detach().numpy().argmax(1)
    cm = wandb.plot.confusion_matrix(y_true=labels, preds=y_pred, class_names=class_names)
    # Log the custom step together with the values: a metric bound to
    # step_metric='test_step' is only plotted against it when 'test_step'
    # is logged too.
    wandb.log({
        'conf_mat': cm,
        'test/accuracy': float((y_pred == labels).mean()),
        'test_step': step,
    })

torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)
torch.Size([1000]) (1000,)


In [100]:
# Registry of `linear_output` values, keyed by the tuple
# (conv_kernel_size, pool_kernel_size, pool_stride).
linear_output_dict = {}


def get_linear_output_dict(conv_kernel_size=3, pool_kernel_size=3, pool_stride=3, **kwargs):
    """Return the registered ``linear_output`` for a layer geometry.

    Any additional keyword arguments are accepted and ignored, so a complete
    model config dict can be splatted into the call. When the geometry has no
    registry entry, 2048 is returned.
    """
    geometry = (conv_kernel_size, pool_kernel_size, pool_stride)
    return linear_output_dict.get(geometry, 2048)


def set_linear_output_dict(linear_output, conv_kernel_size=3, pool_kernel_size=3, pool_stride=3):
    """Store ``linear_output`` for the given geometry, replacing any existing entry."""
    geometry = (conv_kernel_size, pool_kernel_size, pool_stride)
    linear_output_dict.update({geometry: linear_output})
    


# Experiment grid, generated rather than written out row by row.
# Four groups of eight configs each: every group toggles one model flag and
# sweeps both pooling types over the same four layer geometries.
# NOTE(review): the third group ('batch_norm', False) repeats the first group
# verbatim, so those eight configs appear twice -- confirm this is intended.
experiments = [
    {
        'pool': pool,
        flag: enabled,
        'conv_kernel_size': conv_k,
        'pool_kernel_size': pool_k,
        'pool_stride': stride,
        'linear_output': linear,
    }
    for flag, enabled in (
        ('batch_norm', False),
        ('residual', True),
        ('batch_norm', False),
        ('batch_norm', True),
    )
    for pool in ('avg', 'max')
    # (conv_kernel_size, pool_kernel_size, pool_stride, linear_output)
    for conv_k, pool_k, stride, linear in (
        (3, 3, 3, 128),
        (3, 2, 2, 2048),
        (2, 3, 3, 512),
        (2, 2, 2, 2048),
    )
]