In [None]:
# List the GPUs visible to this runtime (IPython `!` runs it as a shell command).
!nvidia-smi -L

In [None]:
# Automatically download the dataset (skipped if it is already present).

import os
# TinyImageNet
# NOTE(review): assumes the Kaggle CLI is installed and authenticated, and that the
# archive unpacks into ./TinyImageNet — confirm against the archive contents.
if not os.path.exists("./TinyImageNet"):
  !kaggle datasets download -d mikewallace250/tiny-imagenet-challenge
  # Extract with the file listing discarded, then delete the archive.
  !unzip tiny-imagenet-challenge.zip > /dev/null
  !rm tiny-imagenet-challenge.zip

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
# Sanity check: load one validation image, print its array shape and display it.
img = mpimg.imread('TinyImageNet/val/6/6_1087.jpg').astype(int)
print(img.shape)
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title('TinyImageNet/val/6/6_1087.jpg')
ax.axis('off')
# plt.show() renders the figure without echoing the AxesImage repr.
plt.show()

In [None]:
# custom transform objects
import torch
from torchvision import datasets, transforms
import random

class MyNormalize(object):
    """Scale an image tensor to [0, 1] and normalize each channel.

    Expects one (C, H, W) tensor with values in [0, `scale`]. For example, this is
    the output of ``transforms.ToTensor()`` applied to a float array, which
    ToTensor does not rescale. The result is
    ``(data / scale - mean[c]) / std[c]`` per channel. A new tensor is returned
    and the input is left unmodified.

    A single-channel (grayscale) image is first replicated to ``len(mean)``
    channels so it can share the RGB statistics.

    Args:
        mean: per-channel means of the [0, 1]-scaled data. The defaults are the
            values previously hard-coded here (presumably TinyImageNet channel
            statistics — confirm).
        std: per-channel standard deviations of the [0, 1]-scaled data.
        scale: value raw pixels are divided by before normalizing.

    Raises:
        ValueError: if `mean`/`std` lengths differ, or the input is not 3-D.
    """

    def __init__(self, mean=(0.4650, 0.4516, 0.3871), std=(0.2671, 0.2579, 0.2722), scale=255):
        if len(mean) != len(std):
            raise ValueError("mean and std must have the same length")
        self.mean = tuple(mean)
        self.std = tuple(std)
        self.scale = scale

    def __call__(self, data):
        if data.dim() != 3:
            raise ValueError("expected a (C, H, W) tensor, got shape {}".format(tuple(data.shape)))
        channels = len(self.mean)
        if data.shape[0] == 1 and channels > 1:
            # Grayscale JPEGs decode to one channel; replicate so the per-channel
            # statistics broadcast instead of failing.
            data = data.expand(channels, -1, -1)
        data = data / self.scale
        # Build the statistics in the data's dtype/device. This is the same
        # arithmetic (subtract, then divide) as transforms.Normalize.
        mean = torch.as_tensor(self.mean, dtype=data.dtype, device=data.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=data.dtype, device=data.device).view(-1, 1, 1)
        return (data - mean) / std

class MyAugmentation(object):
    """Randomly augment an image tensor or a batch of image tensors.

    Each image independently gets:
      * a random horizontal flip (torchvision default p=0.5),
      * a random conversion to grayscale (torchvision default p=0.1),
      * with probability `rotation_prob`, a random rotation in
        [-max_degrees, max_degrees].

    A 4-D tensor is treated as a (B, C, H, W) batch, and every sample gets its own
    random draws. Any other input is passed through the transforms as a single
    image.

    Random cropping is intentionally not used because it might hurt ImageNet-A
    performance.
    """

    def __init__(self, rotation_prob=0.5, max_degrees=15):
        # Build the transform objects once instead of on every call.
        self.flip = transforms.RandomHorizontalFlip()
        self.grayscale = transforms.RandomGrayscale()
        self.rotate = transforms.RandomRotation(max_degrees)
        self.rotation_prob = rotation_prob

    def _augment_one(self, image):
        """Apply the random transforms to a single image."""
        image = self.flip(image)
        image = self.grayscale(image)
        if random.random() < self.rotation_prob:
            image = self.rotate(image)
        return image

    def __call__(self, image):
        # Applying the transforms to the whole batch at once would give every
        # image the same flip/grayscale/rotation draw, so augment each sample
        # separately.
        if isinstance(image, torch.Tensor) and image.dim() == 4:
            return torch.stack([self._augment_one(sample) for sample in image])
        return self._augment_one(image)
      

In [None]:
# load_data.py
from torchvision import datasets, transforms
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import time

class TINDataset(torch.utils.data.Dataset):
  """TinyImageNet dataset holding the train, validation and test splits in memory.

  All three splits are decoded and normalized once in ``__init__``. ``train()``,
  ``validate()`` and ``test()`` select which split ``__getitem__``/``__len__``
  expose. Train/validation items are ``(image, label)`` pairs; test items are
  bare images.

  File layout read by ``__init__`` (relative to the ``*_loc`` arguments):
    train: ``<train_loc><c>/<c>_<j>.jpg``      for j in [0, TRAIN_PER_CLASS)
    val:   ``<val_loc><c>/<c>_10<jj>.jpg``     for jj in 00..99
    test:  ``<test_loc><k>.jpg``               for k in [0, NUM_TEST)
  """
  TRAIN = 0
  VALIDATE = 1
  TEST = 9
  _status = TRAIN  # mark the current state of the dataset

  NUM_CLASSES = 100
  TRAIN_PER_CLASS = 1000
  VAL_PER_CLASS = 100  # must stay <= 100: filenames use a two-digit suffix
  NUM_TEST = 10000
  IMAGE_SHAPE = (3, 64, 64)

  def __init__(self, args):
    """Load every split from the ``train_loc``/``val_loc``/``test_loc`` attributes of `args`."""
    self.train_loc = args.train_loc
    self.val_loc = args.val_loc
    self.test_loc = args.test_loc

    # Images are enumerated class by class, so label c repeats once per image of class c.
    self.train_label = self._make_labels(self.NUM_CLASSES, self.TRAIN_PER_CLASS)
    self.val_label = self._make_labels(self.NUM_CLASSES, self.VAL_PER_CLASS)

    # Transpose the HWC array into a CHW tensor, then normalize it.
    trans = transforms.Compose([transforms.ToTensor(), MyNormalize()])

    train_files = [self.train_loc + str(i) + '/' + str(i) + '_' + str(j) + '.jpg'
                   for i in range(self.NUM_CLASSES) for j in range(self.TRAIN_PER_CLASS)]
    self.train_data = self._load_split('training', 'train_loc:', self.train_loc, train_files, trans)

    val_files = [self.val_loc + str(i) + '/' + str(i) + '_10' + str(j).rjust(2, '0') + '.jpg'
                 for i in range(self.NUM_CLASSES) for j in range(self.VAL_PER_CLASS)]
    self.val_data = self._load_split('validation', 'validation_loc:', self.val_loc, val_files, trans)

    test_files = [self.test_loc + str(i) + '.jpg' for i in range(self.NUM_TEST)]
    self.test_data = self._load_split('test', 'test_loc:', self.test_loc, test_files, trans)

  @staticmethod
  def _make_labels(num_classes, per_class):
    """Return [0]*per_class + [1]*per_class + ... + [num_classes-1]*per_class as a LongTensor."""
    return torch.arange(num_classes).repeat_interleave(per_class)

  def _load_split(self, split_name, loc_label, loc, filenames, trans):
    """Decode `filenames` in order into one (N, *IMAGE_SHAPE) float tensor, logging progress and time."""
    print("----------Loading {} data----------".format(split_name))
    time_start = time.time()
    print(loc_label, loc)
    data = torch.zeros(len(filenames), *self.IMAGE_SHAPE)
    for k, filename in enumerate(filenames):
      data[k] = trans(mpimg.imread(filename).astype(float))
    print('size:', data.shape)
    print("Total time: {:.2f}min".format((time.time() - time_start) / 60))
    print("--------------Finish loading--------------")
    return data

  def __getitem__(self, index):
    if self._status == self.TRAIN:
      return self.train_data[index], self.train_label[index]  # data, label (target)
    elif self._status == self.VALIDATE:
      return self.val_data[index], self.val_label[index]
    else:
      return self.test_data[index]  # the test split is unlabeled

  def __len__(self):
    if self._status == self.TRAIN:
      return self.train_data.shape[0]
    elif self._status == self.VALIDATE:
      return self.val_data.shape[0]
    else:
      return self.test_data.shape[0]

  def train(self):
    """Expose the labeled training split."""
    self._status = self.TRAIN

  def validate(self):
    """Expose the labeled validation split."""
    self._status = self.VALIDATE

  def test(self):
    """Expose the unlabeled test split."""
    self._status = self.TEST


In [None]:
# config.py
import os
import numpy as np
import time
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F

# Create the log directory used by Config.logging; exist_ok keeps re-runs silent and portable.
os.makedirs("log", exist_ok=True)

class Config(object):
    """Hyper-parameters plus the train / validate / predict loops for one model.

    The best-validation state dict is written to ``<save_path>/<model_name>``.
    Log lines are appended to ``log/<model_name>``.
    """

    def __init__(self, model_name):
        self.batch_size = 32
        self.test_batch_size = 10
        self.max_epoch = 20
        self.class_num = 100
        self.learning_rate = 0.003
        self.drop_rate = 0.1
        self.test_epoch = 1  # validate every `test_epoch` epochs
        self.save_path = './checkpoint'
        # exist_ok makes constructing Config again (e.g. re-running a cell) safe.
        os.makedirs(self.save_path, exist_ok=True)
        os.makedirs("log", exist_ok=True)
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.aug = MyAugmentation()

    def logging(self, s, print_=True, log_=True):
        """Print `s` and/or append it plus a newline to ``log/<model_name>``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join("log", self.model_name), 'a') as f_log:
                f_log.write(s + '\n')

    def train(self, model, optimizer, criterion, dataset):
        """Train `model` for `max_epoch` epochs, validating every `test_epoch` epochs.

        Args:
            model: network to train (expected to already be on ``self.device``).
            optimizer: optimizer over ``model``'s parameters.
            criterion: loss called as ``criterion(logits, targets)``.
            dataset: object exposing ``train()``/``validate()`` split switches,
                such as TINDataset.

        Returns:
            List with the per-sample mean training loss of each epoch.
        """
        self.logging("Training Started, using " + str(criterion) + " and " + str(optimizer) + "\n\n")
        losses = []
        best_acc = 0.0
        best_epoch = 0
        model.train()
        dataset.train()
        train_loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        for epoch in range(self.max_epoch):
            train_loss = 0.0
            start_time = time.time()
            train_total = train_correct = 0

            for data, target in train_loader:
                # Augment the CPU batch, then move it to the training device.
                data = self.aug(data).to(self.device)
                target = target.to(self.device)
                optimizer.zero_grad()
                output = model(data)  # predicted logits

                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the epoch value is a per-sample mean.
                train_loss += loss.item() * data.size(0)

                pred = torch.argmax(output, dim=-1).cpu()
                train_total += len(pred)
                train_correct += int(torch.sum(pred == target.squeeze(-1).cpu()))

            # train_total equals the number of samples seen; guarding it avoids
            # a ZeroDivisionError on an empty dataset.
            train_loss = train_loss / train_total if train_total > 0 else 0.0
            losses.append(train_loss)
            acc = train_correct / train_total if train_total > 0 else 0.0
            self.logging('[Epoch {:d}] Training Loss: {:.6f}; Time: {:.2f}min; Acc: {:.4f}'.format(
                epoch + 1, train_loss, (time.time() - start_time) / 60, acc))
            if (epoch + 1) % self.test_epoch == 0:
                self.logging('-' * 70)
                eval_start_time = time.time()
                model.eval()
                dataset.validate()
                try:
                    val_acc = self.test(model, dataset)
                finally:
                    # Restore training mode even if evaluation fails, so the
                    # dataset is not left on the validation split.
                    model.train()
                    dataset.train()
                self.logging('Validate\nTime: {:.2f}min; Accuracy: {:.4f}'.format(
                    (time.time() - eval_start_time) / 60, val_acc))
                if val_acc > best_acc:
                    best_acc = val_acc
                    best_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name))
                self.logging('-' * 70)
        self.logging('Training finished!')
        self.logging('Best epoch: {:d} | acc: {:.4f}'.format(best_epoch, best_acc))
        return losses

    def test(self, model, dataset):
        """Return the top-1 accuracy of `model` on `dataset` (0.0 if it is empty).

        The caller selects eval mode and the dataset split (see `train`).
        Gradients are disabled, so no autograd graph is built.
        """
        data_loader = DataLoader(dataset=dataset, batch_size=self.test_batch_size, shuffle=False)
        total = correct = 0
        with torch.no_grad():
            for data, target in data_loader:
                output = model(data.to(self.device)).cpu()
                # Predicted class = index of the largest logit.
                pred = torch.argmax(output, dim=-1)
                target = target.squeeze(-1).cpu()
                total += len(pred)
                correct += int(torch.sum(pred == target))
        return correct / total if total > 0 else 0.0

    def predict(self, model, test_loader):
        """Return predicted class indices for every batch of `test_loader`, concatenated.

        `test_loader` must yield bare input tensors (no labels). Returns None if
        the loader yields nothing.
        """
        model.eval()  # prep model for *evaluation*
        batch_preds = []
        with torch.no_grad():
            for data in test_loader:
                output = model(data.to(self.device))
                _, pred = torch.max(output, 1)
                batch_preds.append(pred)
        # Concatenate once instead of growing a tensor per batch.
        return torch.cat(batch_preds) if batch_preds else None

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
# Install/upgrade timm, which provides the pretrained ViT backbone used by OurModel.
# NOTE(review): the unpinned `-U` may pull a different timm version on each run;
# consider pinning a version (and `%pip` to target the kernel env) for reproducibility.
!pip3 install -U timm > /dev/null
import timm
from pprint import pprint

class OurModel(nn.Module):
  """Pretrained timm ViT-B/16 followed by a linear head over its 1000 outputs.

  Inputs are resized to `image_size` x `image_size` before the backbone. The
  default 224 matches the resolution in the model name 'vit_base_patch16_224'.

  Args:
    num_classes: number of output classes of the linear head.
    image_size: side length inputs are resized to.
    resize_mode: ``F.interpolate`` mode. The default 'nearest' preserves the
      original behavior; 'bilinear' is a common alternative for upsampling.
  """
  def __init__(self, num_classes=100, image_size=224, resize_mode='nearest'):
    super(OurModel, self).__init__()
    self.image_size = image_size
    self.resize_mode = resize_mode
    self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
    # The backbone keeps its 1000-way head; this layer maps those outputs to our classes.
    self.linear = nn.Linear(1000, num_classes)

  def forward(self, input):
    resized = F.interpolate(input, size=(self.image_size, self.image_size), mode=self.resize_mode)
    return self.linear(self.vit(resized))


In [None]:
import argparse
import torch
from torch import nn

# Device for the model; later cells move the model here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Script-style options. Inside Jupyter, sys.argv holds the kernel's own flags;
# parse_known_args() keeps only the recognized options and ignores the rest.
parser = argparse.ArgumentParser()
parser.add_argument('--train_loc', type = str, default = './TinyImageNet/train/')
parser.add_argument('--val_loc', type = str, default = './TinyImageNet/val/')
parser.add_argument('--test_loc', type = str, default = './TinyImageNet/test/')
parser.add_argument('--model_name', type = str, default = "Temp")
args = parser.parse_known_args()[0]

# Decodes and normalizes every split into memory (100k train + 10k val + 10k test
# images), so this cell can take a while.
dataset = TINDataset(args)

In [None]:
conf = Config(args.model_name)
model = OurModel()
# nn.Module.to moves the parameters in place, so the result need not be reassigned.
model.to(device)

# "The Alchemy furnace": the training setup (loss function and optimizer).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params = model.parameters(), lr=conf.learning_rate, momentum=0.9)

In [None]:
# Train for conf.max_epoch epochs. Config.train saves the best-validation state dict
# and returns the per-epoch training losses.
losses = conf.train(model, optimizer, criterion, dataset)

In [None]:
# Reload the best checkpoint written by Config.train (<save_path>/<model_name>)
# instead of a hard-coded './checkpoint/Temp', so --model_name is respected.
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
model.load_state_dict(torch.load(os.path.join(conf.save_path, conf.model_name), map_location=device))
model = model.to(device)

In [None]:
# submit to Kaggle
dataset.test()
test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10)
result = conf.predict(model, test_loader)

file = open("./submit.csv", mode="w")
file.write("Id,Category\n")
for i in range(len(result)):
  file.write(str(i) + ".jpg," + str(int(result[i])) + "\n")
file.flush()
file.close()

message = input("Input submission message: ").strip()
!kaggle competitions submit -f submit.csv -m $message deep-learning-thu-2020