In [46]:
import torch
import torchvision
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import shutil
import random

In [2]:
data_dir = './'

## Data cleaning

In [8]:
if not os.path.exists('./train'):
    os.mkdir('./train')
if not os.path.exists('./vali'):
    os.mkdir('./vali')
if not os.path.exists('./train_vali'):
    os.mkdir('./train_vali')

In [9]:
with open(os.path.join(data_dir, 'train_cultivar_mapping.csv'), 'r') as f:
    train_cultivar_mapping = pd.read_csv(f)
    labels = train_cultivar_mapping['cultivar'].unique()[:-1]
for i in labels:
    if not os.path.exists(os.path.join('./train', i)):
        os.mkdir(os.path.join('./train', i))
    if not os.path.exists(os.path.join('./vali', i)):
        os.mkdir(os.path.join('./vali', i))
    if not os.path.exists(os.path.join('./train_vali', i)):
        os.mkdir(os.path.join('./train_vali', i))

In [47]:
with open('./train_cultivar_mapping.csv', 'r') as f:
    lines = f.readlines()[1:]
    dict = {}
    for i in lines:
        tmp = i.split(',')
        dict[tmp[0]] = tmp[1][:-1]

In [54]:
# for i in os.listdir(os.path.join(data_dir, 'sorghum-id-fgvc-9/train_images')):
#     file_name = dict[i]
#     shutil.copy(os.path.join(data_dir, 'sorghum-id-fgvc-9/train_images', i), os.path.join('./train_vali', file_name))

In [66]:
# for i in os.listdir(os.path.join(data_dir, 'train_vali')):
#     if not i.startswith('.'):
#         category = os.listdir(os.path.join(data_dir, 'train_vali', i))
#         vali = random.sample(category, int(len(category) * 0.1))
#         for j in vali:
#             shutil.copy(os.path.join(data_dir, 'train_vali', i, j), os.path.join(data_dir, 'vali', i))
#         for j in category:
#             if j not in vali:
#                 shutil.copy(os.path.join(data_dir, 'train_vali', i, j), os.path.join(data_dir, 'train', i))

## Training

In [55]:
transform = {
    'train': torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.CenterCrop((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ]),
    'vali': torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])
}

In [56]:
train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform['train'])
vali_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'vali'), transform=transform['vali'])

In [57]:
batch_size = 32

In [58]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
vali_loader = torch.utils.data.DataLoader(vali_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

In [59]:
model = torchvision.models.resnet50(pretrained=True, progress=True)
for parm in model.parameters():
    parm.requires_grad = False
model.fc = torch.nn.Linear(torchvision.models.resnet50().fc.in_features, 100)

In [60]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [61]:
loss_func = torch.nn.CrossEntropyLoss()

In [62]:
epochs = 1
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
model.to(device)
best_acc = 0
print('Training...')
for epoch in range(epochs):
    model.train()
    loss_count = 0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, batch in enumerate(train_bar):
        img, label = batch
        optimizer.zero_grad()
        img.to(device)
        label.to(device)
        output = model(img)
        loss = loss_func(output, label)
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

    model.eval()
    acc = 0
    with torch.no_grad():
        val_bar = tqdm(vali_loader, file=sys.stdout)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
    val_accurate = acc / len(vali_dataset)
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' % (epoch + 1, loss_count / len(train_loader), val_accurate))

    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(model.state_dict(), './FGVC_model.pth')

print('Best Acc: {:.4f}'.format(best_acc))
print('Training Finished!')

Training...
train epoch[1/1] loss:4.012:  37%|███▋      | 230/626 [29:50<52:10,  7.90s/it]  