In [2]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models

from PIL import Image
import numpy as np
import os
import time
import warnings
import torch.multiprocessing as mp

In [None]:
if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используется устройство: {device}")

Используется устройство: cpu


In [4]:
class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []

        class_map = {"cats": 0, "dogs": 1}

        for class_name, label in class_map.items():
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.exists(class_dir):
                continue

            for img in os.listdir(class_dir):
                if img.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(class_dir, img))
                    self.labels.append(label)

        print(f"{root_dir}: {len(self.image_paths)} images")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image=image)["image"]

        return image, label

In [5]:
def get_transforms(augment=False):
    if augment:
        return A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.5),
            A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            ),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Resize(224, 224),
            A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            ),
            ToTensorV2()
        ])


def create_model(num_classes=2):
    model = models.resnet18(
        weights=models.ResNet18_Weights.IMAGENET1K_V1
    )

    for param in model.parameters():
        param.requires_grad = False

    for param in model.layer4.parameters():
        param.requires_grad = True

    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes)
    )

    return model.to(device)


def train_model(model, dataloaders, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        for phase in ["train", "val"]:
            model.train() if phase == "train" else model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print(f"{phase}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.4f}")

    return model


def get_subset(dataset, fraction):
    size = int(len(dataset) * fraction)
    indices = list(range(size))
    return Subset(dataset, indices)

In [6]:
def main():
    batch_sizes = [16, 32]
    dataset_fractions = {
        "small": 0.2,
        "medium": 0.5,
        "full": 1.0
    }

    for batch_size in batch_sizes:
        print(f"\n===============================")
        print(f"Batch size: {batch_size}")
        print(f"===============================")

        full_train_dataset = CatDogDataset(
            "cats_dogs/train",
            transform=get_transforms(augment=True)
        )
        val_dataset = CatDogDataset(
            "cats_dogs/val",
            transform=get_transforms(augment=False)
        )

        for name, frac in dataset_fractions.items():
            print(f"\n--- Dataset size: {name} ({int(frac * 100)}%) ---")

            train_subset = get_subset(full_train_dataset, frac)

            dataloaders = {
                "train": DataLoader(
                    train_subset,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=0
                ),
                "val": DataLoader(
                    val_dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=0
                )
            }

            model = create_model()

            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=1e-4
            )

            train_model(model, dataloaders, criterion, optimizer, num_epochs=5)

    print("\n=== TESTING MODEL ===")

    model.eval()

    test_dataset = CatDogDataset(
        "cats_dogs/test",
        transform=get_transforms(augment=False)
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=0
    )

    correct, total = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total:.2f}%")

In [7]:
if __name__ == "__main__":
    main()



Batch size: 16
cats_dogs/train: 400 images
cats_dogs/val: 400 images

--- Dataset size: small (20%) ---

Epoch 1/5
train: Loss=0.6179, Acc=0.6750
val: Loss=0.8738, Acc=0.5000

Epoch 2/5
train: Loss=0.1253, Acc=1.0000
val: Loss=1.4023, Acc=0.5000

Epoch 3/5
train: Loss=0.0393, Acc=1.0000
val: Loss=1.8666, Acc=0.5000

Epoch 4/5
train: Loss=0.0144, Acc=1.0000


KeyboardInterrupt: 