In [None]:
import os
import shutil
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from source.model import SimpleNet


In [None]:
# --- Configuration: everything a reader might want to tune lives here. ---
SEED = 42
TRAIN_RATIO = 0.8   # fraction of the non-test images per class that go to the train split
TEST_COUNT = 50     # images per class held out for the test split
BATCH_SIZE = 32
LR = 0.001          # Adam learning rate
EPOCHS = 5
IMG_SIZE = 128      # images are resized to IMG_SIZE x IMG_SIZE

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Seed both RNGs this notebook relies on (Python's `random` for the data split,
# torch for weight init / DataLoader shuffling / augmentation).
torch.manual_seed(SEED)
random.seed(SEED)

def prepare_data(source_dir, target_dir, test_count=None, train_ratio=None, seed=None):
    """Split a flat folder of cat/dog ``*.jpg`` files into an ImageFolder tree.

    Creates ``target_dir/{train,val,test}/{cat,dog}/`` and copies the images
    into it. Per class, the first ``test_count`` shuffled files go to ``test``;
    of the remaining files, the first ``int(len * train_ratio)`` go to
    ``train`` and the rest to ``val``.

    The whole function is skipped if ``target_dir`` already exists, so
    re-running the notebook never re-copies or re-splits the data. (An
    interrupted copy can still leave a partial tree; delete ``target_dir``
    to regenerate it.)

    Args:
        source_dir: Directory holding the source ``*.jpg`` files (not searched
            recursively). A file is assigned to a class when its name contains
            ``'cat'`` / ``'dog'``.
        target_dir: Output root directory.
        test_count: Images per class reserved for ``test``. Defaults to the
            notebook-level ``TEST_COUNT``.
        train_ratio: Fraction of the non-test images used for ``train``.
            Defaults to the notebook-level ``TRAIN_RATIO``.
        seed: Seed for the per-class shuffle. Defaults to the notebook-level
            ``SEED``.

    Raises:
        ValueError: If any class would end up with an empty split (including
            the case of no matching images at all). This is checked before
            anything is written, so no empty tree is left behind for the
            early-exit check to silently reuse.
    """
    if os.path.exists(target_dir):
        return

    test_count = TEST_COUNT if test_count is None else test_count
    train_ratio = TRAIN_RATIO if train_ratio is None else train_ratio
    seed = SEED if seed is None else seed
    # A private RNG keeps the split independent of any other use of the
    # global `random` state elsewhere in the notebook.
    rng = random.Random(seed)

    source_path = Path(source_dir)
    target_path = Path(target_dir)

    # glob() yields entries in filesystem order, which varies between machines;
    # sort so the seeded shuffle produces the same split everywhere.
    files = sorted(source_path.glob('*.jpg'))
    # NOTE(review): substring matching means a name containing both 'cat' and
    # 'dog' would be copied into both classes.
    files_by_class = {
        'cat': [f for f in files if 'cat' in f.name],
        'dog': [f for f in files if 'dog' in f.name],
    }

    # Plan the entire split first so invalid input fails before any I/O.
    plan = {}
    for cls, category_files in files_by_class.items():
        rng.shuffle(category_files)
        remaining = category_files[test_count:]
        split_idx = int(len(remaining) * train_ratio)
        plan[cls] = {
            'test': category_files[:test_count],
            'train': remaining[:split_idx],
            'val': remaining[split_idx:],
        }
        empty = [split for split, split_files in plan[cls].items() if not split_files]
        if empty:
            raise ValueError(
                f"Class '{cls}' has {len(category_files)} image(s) in {source_path}; "
                f"split(s) {empty} would be empty "
                f"(test_count={test_count}, train_ratio={train_ratio})"
            )

    for cls, splits in plan.items():
        for split, split_files in splits.items():
            split_dir = target_path / split / cls
            split_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy(f, split_dir / f.name)

# Build dataset/{train,val,test}/{cat,dog} from the flat `train` image folder (no-op if it exists).
prepare_data(source_dir='train', target_dir='dataset')

In [None]:
def _build_transform(augment):
    """Return the preprocessing pipeline; `augment` adds a random horizontal flip.

    Both variants resize to IMG_SIZE x IMG_SIZE, convert to a tensor and
    normalize each channel with mean 0.5 / std 0.5.
    """
    steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transforms.Compose(steps)


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch size."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2)


# Augmentation only for training; validation and test share the deterministic pipeline.
transform_train = _build_transform(augment=True)
transform_val = _build_transform(augment=False)

train_ds = datasets.ImageFolder('dataset/train', transform=transform_train)
val_ds = datasets.ImageFolder('dataset/val', transform=transform_val)
test_ds = datasets.ImageFolder('dataset/test', transform=transform_val)

train_loader = _make_loader(train_ds, shuffle=True)
val_loader = _make_loader(val_ds, shuffle=False)
test_loader = _make_loader(test_ds, shuffle=False)

In [None]:
model = SimpleNet().to(DEVICE)
# Default reduction='mean': each call returns the mean loss over the batch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)


def train_one_epoch(model, loader, criterion, optimizer, device, desc=None):
    """Run one optimisation pass over `loader` with a tqdm progress bar.

    Returns:
        (mean_loss, accuracy_pct) averaged over every sample seen this epoch.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=desc)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # `loss` is a batch mean; weight by batch size so the epoch figure is a
        # true per-sample mean even when the last batch is smaller.
        loss_sum += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size
        # Average over samples seen *so far* (dividing by the total number of
        # batches would under-report the loss until the epoch finishes).
        pbar.set_postfix({'loss': loss_sum / total, 'acc': 100. * correct / total})

    if total == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / total, 100. * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Compute per-sample mean loss and accuracy (%) of `model` on `loader`.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        batch_size = labels.size(0)
        loss_sum += criterion(outputs, labels).item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return loss_sum / total, 100. * correct / total


print(f"Training on {DEVICE}")

for epoch in range(EPOCHS):
    train_one_epoch(model, train_loader, criterion, optimizer, DEVICE,
                    desc=f"Epoch {epoch+1}/{EPOCHS}")
    val_loss, val_acc = evaluate(model, val_loader, criterion, DEVICE)
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

# Saves the weights from the final epoch (not necessarily the best val epoch).
torch.save(model.state_dict(), 'model.pth')

In [None]:
# Held-out test accuracy of whatever weights `model` currently holds
# (after the training cell, those of its final epoch).
model.eval()
test_correct = 0
test_total = 0

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        preds = model(batch_x).max(dim=1).indices
        test_correct += (preds == batch_y).sum().item()
        test_total += batch_y.size(0)

test_acc = 100. * test_correct / test_total
print(f"Test Accuracy: {test_acc:.2f}%")