In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader,random_split
from torchvision import transforms, datasets, models
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

In [3]:
model = models.resnet18(weights=None)
model.fc=nn.Linear(model.fc.in_features,120)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
mean = [0.4914, 0.4822, 0.4465] 
std = [0.2470, 0.2435, 0.2616] 
batch_size = 64
n_epochs = 100

train_transform = transforms.Compose([ 
transforms.Resize((224,224)), 
transforms.RandomCrop(224, padding=4), 
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), 
transforms.Normalize(mean, std)
])

path='C:/Users/User/DeepLearning/Deep_Learning/Bird_Classification/Images'
all_train = datasets.ImageFolder(root = path, transform = train_transform)
train_size = int(0.9 * len(all_train))
validation_size = len(all_train) - train_size
train_dataset, validation_dataset = random_split(all_train , [train_size, validation_size])

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3
)
val_loader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3
)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
    
def train(model, train_loader, optimizer, loss_fn):
    model.train()
    train_loss = 0.
    corrects=0
    total = 0
    for images, labels in tqdm(train_loader):
        optimizer.zero_grad() # step 1
        images = images.to(device)
        labels = labels.to(device)
        

        logits = model(images) # step 2 (forward pass)
        loss = loss_fn(logits, labels) # step 3 (compute loss)
        _, predictions = torch.max(logits, dim=1)
        corrects += predictions.eq(labels).sum().item()
        total += labels.size(0)
        
        loss.backward() # step 4 (backpropagation)
        optimizer.step()

        train_loss += loss.item()*images.size(0)
       
        
    train_loss = train_loss/len(train_loader.sampler)
    
    return train_loss, corrects/total    


@torch.no_grad()
def validate(model, valid_loader, loss_fn):
    model.eval()
    losses=0.
    corrects=0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            
                
            logits = model(images) # step 2 (forward pass)
            loss = loss_fn(logits, labels) # step 3 (compute loss)
            total += labels.size(0)
            
            _, predictions = torch.max(logits, dim=1)
            corrects += predictions.eq(labels).sum().item()
            
            losses += loss.item()*images.size(0)    
            
        valid_loss = losses/len(valid_loader.sampler)
    return valid_loss, corrects / total

train_loss_list = []
valid_loss_list = []

early_stopper = EarlyStopper(patience=7)

for epoch in range(n_epochs):
    training_loss, training_accuracy = train(model, train_loader, optimizer, loss_fn)
    valid_loss, valid_accuracy = validate(model, val_loader, loss_fn)
    
    train_loss_list.append(training_loss)
    valid_loss_list.append(valid_loss)

    # if scheduler is not None and is_valid_available:
    #     scheduler.step(valid_loss)
    # elif scheduler is not None:
    scheduler.step()

    print(f"Epoch {epoch+1}/{n_epochs}: training accuracy: {training_accuracy}, valid accuracy: {valid_accuracy}")
    
    if early_stopper.early_stop(valid_loss): 
        break


100%|██████████| 290/290 [01:20<00:00,  3.62it/s]
100%|██████████| 33/33 [00:13<00:00,  2.44it/s]


Epoch 1/100: training accuracy: 0.028668610301263362, valid accuracy: 0.034499514091350825


100%|██████████| 290/290 [00:48<00:00,  5.94it/s]
100%|██████████| 33/33 [00:08<00:00,  4.12it/s]


Epoch 2/100: training accuracy: 0.06084656084656084, valid accuracy: 0.06462585034013606


100%|██████████| 290/290 [00:48<00:00,  6.01it/s]
100%|██████████| 33/33 [00:08<00:00,  4.07it/s]


Epoch 3/100: training accuracy: 0.09782960803368966, valid accuracy: 0.0456754130223518


100%|██████████| 290/290 [00:47<00:00,  6.04it/s]
100%|██████████| 33/33 [00:07<00:00,  4.15it/s]


Epoch 4/100: training accuracy: 0.14215527480833604, valid accuracy: 0.11273080660835763


100%|██████████| 290/290 [00:47<00:00,  6.05it/s]
100%|██████████| 33/33 [00:08<00:00,  4.01it/s]


Epoch 5/100: training accuracy: 0.1871828096317892, valid accuracy: 0.19387755102040816


100%|██████████| 290/290 [00:48<00:00,  6.01it/s]
100%|██████████| 33/33 [00:07<00:00,  4.15it/s]


Epoch 6/100: training accuracy: 0.24263038548752835, valid accuracy: 0.18999028182701652


100%|██████████| 290/290 [00:48<00:00,  6.01it/s]
100%|██████████| 33/33 [00:08<00:00,  4.04it/s]


Epoch 7/100: training accuracy: 0.2948385703487744, valid accuracy: 0.20456754130223517


100%|██████████| 290/290 [00:48<00:00,  5.98it/s]
100%|██████████| 33/33 [00:08<00:00,  4.08it/s]


Epoch 8/100: training accuracy: 0.3463448871612137, valid accuracy: 0.29397473275024294


100%|██████████| 290/290 [00:48<00:00,  6.00it/s]
100%|██████████| 33/33 [00:08<00:00,  4.08it/s]


Epoch 9/100: training accuracy: 0.3921822697332901, valid accuracy: 0.30466472303207


100%|██████████| 290/290 [00:47<00:00,  6.04it/s]
100%|██████████| 33/33 [00:08<00:00,  4.05it/s]


Epoch 10/100: training accuracy: 0.4453082820429759, valid accuracy: 0.3275024295432459


100%|██████████| 290/290 [00:48<00:00,  6.02it/s]
100%|██████████| 33/33 [00:07<00:00,  4.15it/s]


Epoch 11/100: training accuracy: 0.4912536443148688, valid accuracy: 0.35082604470359574


100%|██████████| 290/290 [00:48<00:00,  6.01it/s]
100%|██████████| 33/33 [00:07<00:00,  4.14it/s]


Epoch 12/100: training accuracy: 0.5422200626282259, valid accuracy: 0.3935860058309038


100%|██████████| 290/290 [00:48<00:00,  6.01it/s]
100%|██████████| 33/33 [00:08<00:00,  4.08it/s]


Epoch 13/100: training accuracy: 0.5858438613540654, valid accuracy: 0.4042759961127308


100%|██████████| 290/290 [00:48<00:00,  5.99it/s]
100%|██████████| 33/33 [00:08<00:00,  4.11it/s]


Epoch 14/100: training accuracy: 0.6350826044703596, valid accuracy: 0.42517006802721086


100%|██████████| 290/290 [00:49<00:00,  5.82it/s]
100%|██████████| 33/33 [00:08<00:00,  3.71it/s]


Epoch 15/100: training accuracy: 0.6769247381492279, valid accuracy: 0.4368318756073858


100%|██████████| 290/290 [00:55<00:00,  5.27it/s]
100%|██████████| 33/33 [00:08<00:00,  3.68it/s]


Epoch 16/100: training accuracy: 0.7193067703271785, valid accuracy: 0.4577259475218659


100%|██████████| 290/290 [00:55<00:00,  5.25it/s]
100%|██████████| 33/33 [00:08<00:00,  3.68it/s]


Epoch 17/100: training accuracy: 0.7506208832739445, valid accuracy: 0.478134110787172


100%|██████████| 290/290 [00:55<00:00,  5.26it/s]
100%|██████████| 33/33 [00:09<00:00,  3.66it/s]


Epoch 18/100: training accuracy: 0.7754562142317244, valid accuracy: 0.4752186588921283


100%|██████████| 290/290 [00:54<00:00,  5.30it/s]
100%|██████████| 33/33 [00:08<00:00,  3.75it/s]


Epoch 19/100: training accuracy: 0.7960803368966635, valid accuracy: 0.48299319727891155


100%|██████████| 290/290 [00:54<00:00,  5.27it/s]
100%|██████████| 33/33 [00:08<00:00,  3.72it/s]


Epoch 20/100: training accuracy: 0.8010474030882194, valid accuracy: 0.48882410106899904


100%|██████████| 290/290 [00:55<00:00,  5.26it/s]
100%|██████████| 33/33 [00:08<00:00,  3.72it/s]


Epoch 21/100: training accuracy: 0.8084440125256451, valid accuracy: 0.4834791059280855


100%|██████████| 290/290 [00:55<00:00,  5.25it/s]
100%|██████████| 33/33 [00:08<00:00,  3.76it/s]


Epoch 22/100: training accuracy: 0.8058524997300508, valid accuracy: 0.4878522837706511


100%|██████████| 290/290 [00:55<00:00,  5.26it/s]
100%|██████████| 33/33 [00:08<00:00,  3.74it/s]


Epoch 23/100: training accuracy: 0.8062844185293165, valid accuracy: 0.4815354713313897


100%|██████████| 290/290 [00:55<00:00,  5.26it/s]
100%|██████████| 33/33 [00:08<00:00,  3.70it/s]


Epoch 24/100: training accuracy: 0.7982399308929922, valid accuracy: 0.47862001943634597


100%|██████████| 290/290 [00:55<00:00,  5.26it/s]
100%|██████████| 33/33 [00:08<00:00,  3.71it/s]


Epoch 25/100: training accuracy: 0.7937047835007018, valid accuracy: 0.478134110787172


100%|██████████| 290/290 [00:55<00:00,  5.27it/s]
100%|██████████| 33/33 [00:08<00:00,  3.76it/s]

Epoch 26/100: training accuracy: 0.7948925601986826, valid accuracy: 0.4586977648202138





In [5]:
torch.save(model.state_dict(),'RestNet18_dog_model.pt')