In [1]:
import torch
from torchvision.models import resnet50, ResNet50_Weights
from kaggle.dataset.dataUtils import FlowerDataset
from torch.utils.data import DataLoader
from torchvision.transforms import Resize, ToTensor, Normalize, Compose
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train_step(model, loss_fn, optimizer, dataloader):
    model.train()
    for epoch in range(EPOCH):
        print(f'Current Epoch is {epoch + 1}')
        for y, x in tqdm(dataloader):
            optimizer.zero_grad()
            pred = model(x)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
        print(f'Current loss = {loss.item()}')



def test_step(model, dataloader):
    correct = 0.
    model.eval()
    for y, x in dataloader:
        pred = model(x)
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    print(f'Total data {len(dataloader.dataset)}')
    print(f'Correct = {correct}')
    acc = correct/len(dataloader.dataset)
    print(f'The acc is {acc*100.}%')

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"
EPOCH = 20
BATCH_SIZE = 32
ROOT_PATH = './kaggle/dataset/flowers_/'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [4]:
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(DEVICE)

In [5]:
train_dataset = FlowerDataset(
    root_path=ROOT_PATH, 
    split='train', 
    transform=Compose([
        Resize(size=(150, 150)),
        ToTensor(),
        Normalize(mean, std)
        ]),
        device=DEVICE
    )
test_dataset = FlowerDataset(
        root_path=ROOT_PATH, 
        split='test',
        transform=Compose([
        Resize(size=(150, 150)),
        ToTensor(),
        Normalize(mean, std)
        ]),
        device=DEVICE)


In [6]:
trainDataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
testDataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [7]:
outputLayer = torch.nn.Linear(1000, 5)
model_s = torch.nn.Sequential(model, outputLayer).to(DEVICE)

In [8]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_s.parameters(), lr=1e-3)

In [9]:
train_step(model_s, loss_fn, optimizer, trainDataloader)
test_step(model_s, trainDataloader)
test_step(model_s, testDataloader)

Current Epoch is 1


100%|██████████| 95/95 [00:13<00:00,  7.22it/s]


Current loss = 0.5529111623764038
Current Epoch is 2


100%|██████████| 95/95 [00:10<00:00,  8.83it/s]


Current loss = 0.7531698942184448
Current Epoch is 3


100%|██████████| 95/95 [00:10<00:00,  8.85it/s]


Current loss = 0.07397636026144028
Current Epoch is 4


100%|██████████| 95/95 [00:10<00:00,  8.84it/s]


Current loss = 0.32712322473526
Current Epoch is 5


100%|██████████| 95/95 [00:10<00:00,  8.84it/s]


Current loss = 0.30748218297958374
Current Epoch is 6


100%|██████████| 95/95 [00:10<00:00,  8.83it/s]


Current loss = 0.8707674145698547
Current Epoch is 7


100%|██████████| 95/95 [00:10<00:00,  8.82it/s]


Current loss = 0.07193703204393387
Current Epoch is 8


100%|██████████| 95/95 [00:11<00:00,  8.49it/s]


Current loss = 0.05878368020057678
Current Epoch is 9


100%|██████████| 95/95 [00:11<00:00,  8.59it/s]


Current loss = 0.04986454173922539
Current Epoch is 10


100%|██████████| 95/95 [00:10<00:00,  8.80it/s]


Current loss = 0.011146982200443745
Current Epoch is 11


100%|██████████| 95/95 [00:10<00:00,  8.70it/s]


Current loss = 0.025362379848957062
Current Epoch is 12


100%|██████████| 95/95 [00:11<00:00,  8.63it/s]


Current loss = 0.06425347179174423
Current Epoch is 13


100%|██████████| 95/95 [00:11<00:00,  8.40it/s]


Current loss = 2.290820837020874
Current Epoch is 14


100%|██████████| 95/95 [00:10<00:00,  8.67it/s]


Current loss = 0.21491971611976624
Current Epoch is 15


100%|██████████| 95/95 [00:10<00:00,  8.74it/s]


Current loss = 0.2626623213291168
Current Epoch is 16


100%|██████████| 95/95 [00:10<00:00,  8.66it/s]


Current loss = 0.1239720806479454
Current Epoch is 17


100%|██████████| 95/95 [00:10<00:00,  8.68it/s]


Current loss = 0.2578270435333252
Current Epoch is 18


100%|██████████| 95/95 [00:10<00:00,  8.73it/s]


Current loss = 0.05555441603064537
Current Epoch is 19


100%|██████████| 95/95 [00:10<00:00,  8.71it/s]


Current loss = 0.0035025556571781635
Current Epoch is 20


100%|██████████| 95/95 [00:10<00:00,  8.68it/s]


Current loss = 0.0004833537677768618
Total data 3027
Correct = 3023.0
The acc is 99.86785596299967%
Total data 1296
Correct = 1053.0
The acc is 81.25%
