In [28]:
import torch
import torch.nn as nn

from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from torchvision import transforms
from tqdm import tqdm


In [2]:
DATASET_NAME = 'cats_vs_dogs'
datasets = load_dataset(DATASET_NAME)
datasets


DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 23410
    })
})

# Split dataset

In [3]:
TEST_SIZE = 0.2
datasets = datasets['train'].train_test_split(test_size=TEST_SIZE)


# Build Dataloader

### Build transformss function top preproces images

In [21]:
IMG_SIZE = 64
img_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])


# CatDogDataset class

In [22]:
class CatDogDataset(Dataset):
    def __init__(self, data, transform):
        self.data = data
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        images = self.data[idx]['image']
        labels = self.data[idx]['labels']
        
        if self.transform:
            images = self.transform(images)
        
        labels = torch.tensor(labels, dtype=torch.long)
    
        return images, labels


## Initialize DataLoader

In [23]:
TRAIN_BATCH_SIZE = 32
VAL_BATCH_SIZE = 32

train_dataset = CatDogDataset(datasets['train'], transform=img_transforms)
test_dataset = CatDogDataset(datasets['test'], transform=img_transforms)

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=0)


# Build model

- build model with backbone is pre-trained resnet18 for classify cat and dog

In [24]:
class CatDogModel(nn.Module):
    def __init__(self, n_classes):
        super(CatDogModel, self).__init__()
        
        resnet_model = resnet18(weights='IMAGENET1K_V1')
        self.backbone = nn.Sequential(*list(resnet_model.children())[:-1])
        for param in self.backbone.parameters():
            param.requires_grad = False
        
        in_features = resnet_model.fc.in_features
        self.fc = nn.Linear(in_features, n_classes)
        
    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x  = self.fc(x)
        
        return x


In [25]:
device = 'cuda' if torch.cuda.is_available() else "cpu"
N_CLASSES = 2

model = CatDogModel(N_CLASSES).to(device)
test_input = torch.rand(1, 3, 224, 224).to(device)

with torch.no_grad():
    output = model(test_input)
    print(output.shape) # (1, 2)


torch.Size([1, 2])


# train model


In [29]:
EPOCHS = 10
LR = 1e-3
WEIGHT_DECAY = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    train_losses = []
    model.train()
    
    print(f"\n[Epoch {epoch + 1}] Training...")
    for images, labels in tqdm(train_loader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_losses.append(loss.item())
        
    train_losses = sum(train_losses) / len(train_losses)
    
    val_losses = []
    model.eval()
    print(f"[Epoch {epoch + 1}] Validating...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Validation", leave=False):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_losses.append(loss.item())
            
    val_losses = sum(val_losses) / len(val_losses)
    print(f"EPOCH {epoch + 1}: Train loss: {train_losses:.3f}\tVal loss: {val_losses:.3f}")



[Epoch 1] Training...


                                                           

[Epoch 1] Validating...


                                                             

EPOCH 1: Train loss: 0.548	Val loss: 0.583

[Epoch 2] Training...


                                                           

[Epoch 2] Validating...


                                                             

EPOCH 2: Train loss: 0.539	Val loss: 0.538

[Epoch 3] Training...


                                                           

[Epoch 3] Validating...


                                                             

EPOCH 3: Train loss: 0.539	Val loss: 0.531

[Epoch 4] Training...


                                                           

[Epoch 4] Validating...


                                                             

EPOCH 4: Train loss: 0.545	Val loss: 0.522

[Epoch 5] Training...


                                                           

[Epoch 5] Validating...


                                                             

EPOCH 5: Train loss: 0.535	Val loss: 0.556

[Epoch 6] Training...


                                                           

[Epoch 6] Validating...


                                                             

EPOCH 6: Train loss: 0.535	Val loss: 0.536

[Epoch 7] Training...


                                                           

[Epoch 7] Validating...


                                                             

EPOCH 7: Train loss: 0.546	Val loss: 0.543

[Epoch 8] Training...


                                                           

[Epoch 8] Validating...


                                                             

EPOCH 8: Train loss: 0.548	Val loss: 0.545

[Epoch 9] Training...


                                                           

[Epoch 9] Validating...


                                                             

EPOCH 9: Train loss: 0.541	Val loss: 0.565

[Epoch 10] Training...


                                                           

[Epoch 10] Validating...


                                                             

EPOCH 10: Train loss: 0.551	Val loss: 0.536




In [30]:
SAVE_PATH = 'catdog_weights.pt'
torch.save(model.state_dict(), SAVE_PATH)
