该 Notebook 是在 kaggle 进行的测试, 因此数据路径的格式为 `/kaggle/input/...`.

In [116]:
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from tqdm import tqdm

# dataset
import os
import math
import glob
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import Dataset, Subset, DataLoader
import matplotlib.pyplot as plt

# save result
import pickle

In [117]:
torch.manual_seed(0)
try:
    device = torch.device("mps")
except:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f'Current Device : {device}')

Current Device : cuda


In [128]:
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os 
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
from PIL import Image

channel_mean = torch.Tensor([0.485, 0.456, 0.406])
channel_std = torch.Tensor([0.229, 0.224, 0.225])

class MyDataset(Dataset):
    def __init__(self,root_dir,names_file,transform=None):
        self.root_dir = root_dir #根目录
        self.names_file = names_file #.txt文件路径
        self.transform = transform #数据预处理
        self.size = 0 #数据集大小
        self.names_list = [] #数据集路径列表
        self.trans=transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        
        if not os.path.isfile(self.names_file):
            print(self.names_file + 'does not exist!')
        file = open(self.names_file)
        for f in file: #循环读取.txt文件总每行数据信息
            self.names_list.append(f)
            self.size += 1
        
    def __len__(self):
        return self.size
    
    def __getitem__(self,index):
        image_path = self.root_dir + self.names_list[index].split(' ')[0] #获取图片数据路径
        if not os.path.isfile(image_path):
            print(image_path + 'does not exist!')
            return None
        image = io.imread(image_path) #读取图片
        image = Image.fromarray(image).resize((224, 224))
        image=image.convert('RGB')
        image=self.trans(image)
        label = int(self.names_list[index].split(' ')[1]) #读取标签
 
        return image,label
        
        #sample = {'image':image,'label':lable}
        #if self.transform:
        #    sample = self.transform(sample) 
            
        #return sample #返回图片及对应的标签

root_dir='/kaggle/input/ftd-time/FTD/train'
names_file='/kaggle/input/ftd-time/FTD/train/train.txt'

train_dataset = MyDataset(root_dir=root_dir,names_file=names_file,transform=None)
train_dataloader = DataLoader(dataset=train_dataset,batch_size=4,shuffle=True,num_workers=2)

root_dir='/kaggle/input/ftd-time/FTD/val'
names_file='/kaggle/input/ftd-time/FTD/val/val.txt'

valid_dataset = MyDataset(root_dir=root_dir,names_file=names_file,transform=None)
valid_dataloader = DataLoader(dataset=valid_dataset,batch_size=4,shuffle=False,num_workers=2)

## Build Model

In [129]:
class PretrainViT(nn.Module):

    def __init__(self):
        super(PretrainViT, self).__init__()
        model = models.vit_l_16(pretrained=True)
        num_classifier_feature = model.heads.head.in_features
        model.heads.head = nn.Sequential(
            nn.Linear(num_classifier_feature, 256),
            nn.ReLU(),
#             nn.BatchNorm2d(256),
            nn.Linear(256, 2)
        )
        self.model = model

        for param in self.model.named_parameters():
            if "heads" not in param[0]:
                param[1].requires_grad = False

    def forward(self, x):
        return self.model(x)

In [130]:
net = PretrainViT()

print(f"number of paramaters: {sum([param.numel() for param in net.parameters() if param.requires_grad])}")

number of paramaters: 262914


In [131]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

PretrainViT(
  (model): VisionTransformer(
    (conv_proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU()
            (dropout_1): Dropout(p=0.0, inplace=False)
            (linear_2): Linear(in_features=4096, out_features=1024, bias=True)
            (dropout_2): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(


## Train Model

### Loss Function and Optimizer

In [132]:
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), lr=0.009, momentum=0.9)
optimizer = optim.AdamW(net.parameters(), lr=0.009, weight_decay=5e-4)

### Training Loop 

In [133]:
def get_accuracy(output, label):
    output = output.to("cpu")
    label = label.to("cpu")

    sm = F.softmax(output, dim=1)
    _, index = torch.max(sm, dim=1)
    return torch.sum((label == index)) / label.size()[0]

In [134]:
def train(model, dataloader):
    model.train()
    running_loss = 0.0
    total_loss = 0.0
    running_acc = 0.0
    total_acc = 0.0

    for batch_idx, (batch_img, batch_label) in enumerate(dataloader):

        batch_img = batch_img.to(device)
        batch_label = batch_label.to(device)
#         print(batch_img.shape)
        optimizer.zero_grad()
        output = net(batch_img)
        loss = criterion(output, batch_label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total_loss += loss.item()

        acc = get_accuracy(output, batch_label)
        running_acc += acc
        total_acc += acc

        if batch_idx % 100 == 0 and batch_idx != 0:
            print(f"[step: {batch_idx:4d}/{len(dataloader)}] loss: {running_loss / 100:.3f}")
            running_loss = 0.0
            running_acc = 0.0
    
    return total_loss / len(dataloader), total_acc / len(dataloader)

In [135]:
def validate(model, dataloader):
    model.eval()
    total_loss = 0.0
    total_acc = 0.0

    for batch_idx, (batch_img, batch_label) in enumerate(dataloader):

        batch_img = batch_img.to(device)
        batch_label = batch_label.to(device)

        # optimizer.zero_grad()
        output = net(batch_img)
        loss = criterion(output, batch_label)
        # loss.backward()
        # optimizer.step()

        total_loss += loss.item()
        acc = get_accuracy(output, batch_label)
        total_acc += acc
    
    return total_loss / len(dataloader), total_acc / len(dataloader)

In [136]:
EPOCHS = 50
train_loss_history = []
valid_loss_history = []

train_acc_history = []
valid_acc_history = []

for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

    train_acc_history.append(train_acc)
    valid_acc_history.append(valid_acc)

    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "net.pt")

Epoch:  0, training loss: 1.540, training acc: 0.389 validation loss: 0.694, validation acc: 0.472
Epoch:  1, training loss: 0.700, training acc: 0.481 validation loss: 0.693, validation acc: 0.500
Epoch:  2, training loss: 0.697, training acc: 0.454 validation loss: 0.693, validation acc: 0.472
Epoch:  3, training loss: 0.692, training acc: 0.574 validation loss: 0.690, validation acc: 0.556
Epoch:  4, training loss: 0.704, training acc: 0.509 validation loss: 0.708, validation acc: 0.528
Epoch:  5, training loss: 0.702, training acc: 0.528 validation loss: 0.692, validation acc: 0.528
Epoch:  6, training loss: 0.692, training acc: 0.528 validation loss: 0.692, validation acc: 0.528
Epoch:  7, training loss: 0.693, training acc: 0.528 validation loss: 0.692, validation acc: 0.528
Epoch:  8, training loss: 0.695, training acc: 0.528 validation loss: 0.692, validation acc: 0.528
Epoch:  9, training loss: 0.695, training acc: 0.528 validation loss: 0.692, validation acc: 0.528
Epoch: 10,

## Predict on Test Dataset and Submit to Kaggle

In [137]:
net = PretrainViT()
net.load_state_dict(torch.load("./net.pt", map_location="cpu"))
net.to(device)
net.eval()

PretrainViT(
  (model): VisionTransformer(
    (conv_proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU()
            (dropout_1): Dropout(p=0.0, inplace=False)
            (linear_2): Linear(in_features=4096, out_features=1024, bias=True)
            (dropout_2): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(


In [142]:
root_dir='/kaggle/input/ftd-time/FTD/test'
names_file='/kaggle/input/ftd-time/FTD/test/test.txt'

test_dataset = MyDataset(root_dir=root_dir,names_file=names_file,transform=None)
test_dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=False,num_workers=2)

In [143]:
validate(net, test_dataloader)

(0.6824165999889373, tensor(0.5750))