In [1]:
import argparse
import os
import random
import shutil
import time
import warnings
import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# import torchvision.models as models
from resnet import *

from main import *

# Sorted names of every lowercase, non-dunder callable found in the `models`
# namespace. NOTE(review): `models` is not imported explicitly in this notebook
# (the torchvision import is commented out) -- presumably it arrives through one
# of the star imports; confirm.
model_names = sorted(
    entry
    for entry, member in models.__dict__.items()
    if callable(member) and entry.islower() and not entry.startswith("__")
)

In [2]:
# Run configuration. The values are collected in an EasyDict (attribute-style
# access) instead of calling `parser.parse_args()`, so the notebook does not
# depend on command-line arguments.
import easydict
args = easydict.EasyDict({
    # Identifier-safe spelling so that `args.batch_size` works as an attribute.
    'batch_size': 256,
    # Legacy hyphenated key: it can only be read as args["batch-size"], never
    # with dotted access. Kept so any existing lookup keeps working.
    'batch-size': 256,
    'epochs': 50,
    'data': 0,
    'arch': 'resnet18',
    'lr': 0.1,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'start_epoch': 0,
    'gpu': 0,
})


In [3]:
# Report how many CUDA devices are visible, then pick the device that the model
# and the loss are moved to. Fall back to the CPU when CUDA is unavailable so
# the later `.to(device)` calls do not fail on a machine without a GPU.
ngpus_per_node = torch.cuda.device_count()
print(ngpus_per_node)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

1


In [4]:
# Read the embedding matrix from ../data and echo its dimensions as the cell
# output.
import numpy as np
imagenet_embeding = np.load(file='../data/imagenet_embeding.npy')
np.shape(imagenet_embeding)

(1000, 768)

In [5]:
# Convert the loaded NumPy array into a torch tensor (it is copied into the
# model's state dict in a later cell). `torch.as_tensor` is used rather than
# `torch.tensor` so that re-running this cell, when the name already holds a
# tensor, simply returns it instead of copy-constructing a tensor from a
# tensor (which PyTorch warns about).
imagenet_embeding = torch.as_tensor(imagenet_embeding)
imagenet_embeding.shape

torch.Size([1000, 768])

In [6]:
# Build the network, the loss and the optimizer.
# The log line says "creating" because the model is constructed with
# `pretrained=False`; the previous "using pre-trained model" message was
# misleading.
print("=> creating model '{}'".format('resnet18'))
model = resnet18(pretrained=False)

# Only the loss is placed on `device` here; the model itself is moved by a
# later `model.to(device)` cell.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
# Disabled alternative, kept for reference:
# optimizer = torch.optim.Adam(
#     model.parameters(), lr=0.1)


=> using pre-trained model 'resnet18'


In [7]:
# Peek at a single entry of the `fc` weight before the saved state is loaded.
model.fc.weight[0, 0]

tensor(-0.0136, grad_fn=<SelectBackward>)

In [8]:
# Load the stored initial parameters from ../trained_model/init.pt.
# The file was written once with the (now disabled) line:
#     torch.save(model.state_dict(), '../trained_model/init.pt')
# `map_location='cpu'` keeps the loaded tensors on the CPU regardless of the
# device they were saved from; the model has not been moved to `device` yet at
# this point in the notebook.
model.load_state_dict(torch.load('../trained_model/init.pt', map_location='cpu'))

<All keys matched successfully>

In [9]:
# Same entry of the `fc` weight again, now that the saved state has been loaded.
model.fc.weight[0, 0]

tensor(-0.0301, grad_fn=<SelectBackward>)

In [10]:
# Swap every state-dict entry whose key contains 'fc.weight' for the embedding
# tensor, printing the shape and values being replaced, then load the edited
# dict back into the model.
model_dict = model.state_dict()
fc_weight_keys = [key for key in model_dict if 'fc.weight' in key]
for key in fc_weight_keys:
    print(model_dict[key].shape)
    print(model_dict[key])
    model_dict[key] = imagenet_embeding
model.load_state_dict(model_dict)

# Fetch a fresh state dict and print the same entries again to show the values
# the model now holds.
model_dict = model.state_dict()
for key, value in model_dict.items():
    if 'fc.weight' in key:
        print(value)

torch.Size([1000, 768])
tensor([[-0.0301,  0.0198, -0.0104,  ..., -0.0012, -0.0136,  0.0011],
        [-0.0023, -0.0009, -0.0329,  ..., -0.0281, -0.0132,  0.0052],
        [ 0.0036, -0.0155, -0.0112,  ..., -0.0055,  0.0028,  0.0204],
        ...,
        [-0.0349, -0.0049, -0.0264,  ..., -0.0168, -0.0148,  0.0045],
        [ 0.0063, -0.0293, -0.0185,  ..., -0.0337, -0.0108,  0.0270],
        [-0.0275, -0.0346,  0.0034,  ..., -0.0311,  0.0181,  0.0030]])
tensor([[ 0.5611,  0.2227,  0.2527,  ..., -0.1777,  0.2208,  0.6253],
        [ 0.0625, -0.3026, -0.2862,  ..., -0.0785, -0.1146, -0.1183],
        [ 0.1750,  0.1851, -0.3477,  ..., -0.3395, -0.0346, -0.3034],
        ...,
        [-0.3333,  0.0872, -0.0674,  ...,  0.2487, -0.1052,  0.0210],
        [-0.0106,  0.0222,  0.2349,  ..., -0.4712, -0.5066,  0.3787],
        [ 0.6112,  0.1495, -0.0799,  ..., -0.2983, -0.5161,  0.1615]])


In [11]:
# Turn off gradient tracking for the `fc` weight so autograd computes no
# gradient for it. Only the weight is changed here; `fc.bias` is not touched.
model.fc.weight.requires_grad = False

In [12]:
# for p in model.parameters() :
#     print(p)

In [13]:
# model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [14]:
# Dataset locations: the root folder and its 'train' / 'val' sub-folders.
data_dir = '../ILSVRC/Data/CLS-LOC/'
traindir, valdir = (os.path.join(data_dir, split) for split in ('train', 'val'))

In [15]:
# Per-channel normalisation shared by the training and validation pipelines.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random resized crop to 224 and random horizontal flip,
# then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Validation pipeline: resize to 256 and centre crop to 224, then tensor
# conversion and normalisation.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# Image-folder datasets rooted at the train / val directories.
train_dataset = datasets.ImageFolder(traindir, train_transform)
val_dataset = datasets.ImageFolder(valdir, val_transform)

In [16]:
# next(iter(train_dataset))

In [17]:
# No sampler is used: with `train_sampler` set to None the training DataLoader
# shuffles instead (it is built with `shuffle=(train_sampler is None)`).
# The distributed sampler below is left disabled.
# train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_sampler = None

In [18]:
# Training loader: batches of 256 with 8 worker processes and pinned memory.
# Shuffling is enabled only when no explicit sampler is supplied.
train_loader_options = dict(
    batch_size=256,
    shuffle=train_sampler is None,
    sampler=train_sampler,
    num_workers=8,
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **train_loader_options)

In [19]:
# Validation loader: fixed order, batches of 32 with 4 worker processes and
# pinned memory.
val_loader_options = dict(
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
val_loader = torch.utils.data.DataLoader(val_dataset, **val_loader_options)

In [20]:
# Move the model to the selected device. As the last expression of the cell,
# the returned module is echoed as the cell output.
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [21]:
# Run bookkeeping: best validation value seen so far, the most recent one, and
# per-epoch histories of what `train` and `validate` return.
best_acc1, acc1 = 0, 0
train_loss, val_acc = [], []

for epoch in range(args.start_epoch, args.epochs):
    adjust_learning_rate(optimizer, epoch, args)

    # One training pass followed by an evaluation on the validation loader.
    epoch_loss = train(train_loader, model, criterion, optimizer, epoch, args)
    acc1 = validate(val_loader, model, criterion, args)

    train_loss.append(epoch_loss)
    val_acc.append(acc1)
    print('************train_loss {} val_acc {}*************'.format(epoch_loss, acc1))

    # Record whether this epoch beat the previous best, then update the best.
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    # The multiprocessing / rank guard is disabled, so a checkpoint is handed
    # to save_checkpoint on every epoch:
    #     if not args.multiprocessing_distributed or (args.multiprocessing_distributed
    #             and args.rank % ngpus_per_node == 0):
    checkpoint = {
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, is_best)

Epoch: [0][   0/5005]	Time  2.852 ( 2.852)	Data  2.473 ( 2.473)	Loss 9.6591e+00 (9.6591e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
Epoch: [0][1001/5005]	Time  0.390 ( 0.387)	Data  0.000 ( 0.003)	Loss 6.9821e+00 (7.0464e+00)	Acc@1   0.39 (  0.35)	Acc@5   1.56 (  1.55)
Epoch: [0][2002/5005]	Time  0.390 ( 0.388)	Data  0.000 ( 0.002)	Loss 6.5474e+00 (6.8368e+00)	Acc@1   0.00 (  0.57)	Acc@5   3.52 (  2.36)
Epoch: [0][3003/5005]	Time  0.391 ( 0.388)	Data  0.000 ( 0.001)	Loss 6.3105e+00 (6.6725e+00)	Acc@1   2.73 (  0.86)	Acc@5   7.42 (  3.40)
Epoch: [0][4004/5005]	Time  0.390 ( 0.389)	Data  0.000 ( 0.001)	Loss 6.0042e+00 (6.5211e+00)	Acc@1   3.12 (  1.27)	Acc@5  10.16 (  4.68)
 * Acc@1 3.727 Acc@5 11.476
************train_loss 6.388201045132541 val_acc 3.7274527549743652*************
Epoch: [1][   0/5005]	Time  2.266 ( 2.266)	Data  2.124 ( 2.124)	Loss 5.9463e+00 (5.9463e+00)	Acc@1   3.91 (  3.91)	Acc@5  10.55 ( 10.55)
Epoch: [1][1001/5005]	Time  0.405 ( 0.392)	Data  0.000 ( 0.003)	Loss 