<h3>Imports</h3>

In [56]:
import os
import shutil
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from tqdm import tqdm

<h3>Dataset</h3>

In [57]:
class Dataset2class(torch.utils.data.Dataset):
    """Binary image-classification dataset backed by two directories.

    Images in ``path_dir1`` get label 0 and images in ``path_dir2`` get label 1.
    Index ``i`` addresses the concatenation ``dir1_list + dir2_list``. Each item
    is ``{'img': float32 tensor (3, 64, 64), RGB scaled to [0, 1],
    'label': int64 scalar tensor}``.

    Args:
        path_dir1: directory holding class-0 images.
        path_dir2: directory holding class-1 images.
        validate_images: if True, decode every candidate file with OpenCV at
            construction time and drop those that fail (slow but thorough).
            Dropped names are recorded in ``removed_files_summary``.
        validate_sample_size: currently unused. Validation, when enabled,
            decodes every file. The argument is accepted so existing calls
            keep working.
        remove_bad: if True and ``bad_dir`` is set, files that fail to decode
            inside ``__getitem__`` are moved into ``bad_dir``.
        bad_dir: destination directory for moved unreadable files.

    Raises:
        RuntimeError: if either directory ends up with no usable images.

    Note:
        ``__getitem__`` deletes unreadable files from the in-memory lists, so
        ``len(self)`` can shrink while iterating. Those removals are recorded in
        ``runtime_removed``.
    """

    def __init__(self, path_dir1:str, path_dir2:str, validate_images:bool=False, validate_sample_size:int=100, remove_bad:bool=False, bad_dir: str = None):
        super().__init__()

        self.path_dir1 = path_dir1
        self.path_dir2 = path_dir2
        self.remove_bad = remove_bad
        self.bad_dir = bad_dir

        allowed_ext = ('.jpg', '.jpeg', '.png', '.bmp')
        # Cheap pre-filter: allowed extension, regular file, non-zero size.
        dir1_files = self._list_images(path_dir1, allowed_ext)
        dir2_files = self._list_images(path_dir2, allowed_ext)

        if validate_images:
            # Expensive: actually decode every file and keep only readable ones.
            self.dir1_list, removed1 = self._filter_valid_files(path_dir1, dir1_files)
            self.dir2_list, removed2 = self._filter_valid_files(path_dir2, dir2_files)
            self.removed_files_summary = {path_dir1: removed1, path_dir2: removed2}
            self.validated = True
        else:
            self.dir1_list, self.dir2_list = dir1_files, dir2_files
            self.removed_files_summary = {path_dir1: [], path_dir2: []}
            self.validated = False

        # Files dropped by __getitem__ because they failed to decode at read time.
        self.runtime_removed = {path_dir1: [], path_dir2: []}

        if not self.dir1_list or not self.dir2_list:
            raise RuntimeError(f"No valid images found in {path_dir1} or {path_dir2} (checked extensions {allowed_ext})")

    @staticmethod
    def _list_images(dir_path, allowed_ext):
        """Return sorted names in ``dir_path`` that have an allowed extension and are non-empty regular files."""
        names = []
        for fn in os.listdir(dir_path):
            full_path = os.path.join(dir_path, fn)
            if (fn.lower().endswith(allowed_ext)
                    and os.path.isfile(full_path)
                    and os.path.getsize(full_path) > 0):
                names.append(fn)
        return sorted(names)

    @staticmethod
    def _filter_valid_files(dir_path, files):
        """Split ``files`` into ``(decodable, undecodable)`` by trying ``cv2.imread`` on each."""
        valid, removed = [], []
        for fn in files:
            try:
                img = cv2.imread(os.path.join(dir_path, fn), cv2.IMREAD_COLOR)
            except Exception:
                # Treat any decoder error the same as an unreadable file.
                img = None
            if img is None or img.size == 0:
                removed.append(fn)
            else:
                valid.append(fn)
        return valid, removed

    def _quarantine(self, img_path, list_key, filename):
        """Best-effort move of an unreadable file into ``self.bad_dir`` (only when ``remove_bad`` is on)."""
        if not (self.remove_bad and self.bad_dir):
            return
        # Prefix with the sanitised source directory so same-named files coming
        # from the two class folders cannot collide inside bad_dir.
        prefix = list_key.replace(os.sep, '_').replace('/', '_').replace(':', '')
        dest = os.path.join(self.bad_dir, f"{prefix}_{filename}")
        try:
            os.makedirs(self.bad_dir, exist_ok=True)
            shutil.move(img_path, dest)
        except OSError as exc:
            # The file is already dropped from the in-memory list, so training can
            # continue; just make it visible that it is still on disk.
            warnings.warn(f"Could not move unreadable file {img_path!r} to {dest!r}: {exc}")

    def __getitem__(self, idx):
        """Return the sample at ``idx``, dropping any unreadable file it lands on.

        ``idx`` is wrapped modulo the *current* length, because the sampler's
        indices were drawn before any runtime removals shrank the lists. If the
        file at that position fails to decode, it is recorded in
        ``runtime_removed``, optionally moved to ``bad_dir``, and removed from
        its list. The same position is then retried; it now holds the next file,
        which may belong to the other class.

        Raises:
            RuntimeError: if the dataset is (or becomes) empty, or if no readable
                image was found within ``len(self)`` attempts.
        """
        if len(self) == 0:
            raise RuntimeError("Dataset is empty (all files removed or invalid).")

        # range() is evaluated once, so this is at most one attempt per file present
        # on entry; every failed attempt removes a file, which bounds the work.
        for _ in range(len(self)):
            total_len = len(self)
            if total_len == 0:
                raise RuntimeError("Dataset is empty (all files removed or invalid).")

            # Python's % maps any int (including negatives) into [0, total_len).
            idx = idx % total_len

            n_first = len(self.dir1_list)
            if idx < n_first:
                class_id, list_ref, list_key, local_idx = 0, self.dir1_list, self.path_dir1, idx
            else:
                # idx < n_first + len(dir2_list), so local_idx is always a valid index.
                class_id, list_ref, list_key, local_idx = 1, self.dir2_list, self.path_dir2, idx - n_first
            filename = list_ref[local_idx]
            img_path = os.path.join(list_key, filename)

            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None or img.size == 0:
                self.runtime_removed[list_key].append(filename)
                self._quarantine(img_path, list_key, filename)
                del list_ref[local_idx]
                continue

            # Good image: BGR -> RGB, scale to [0, 1], resize, HWC -> CHW.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = img.astype(np.float32) / 255.0
            img = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)
            img = img.transpose((2, 0, 1))

            t_img = torch.from_numpy(img)
            t_class_id = torch.tensor(class_id, dtype=torch.long)

            return {'img': t_img, 'label': t_class_id}

        # Every attempt landed on an unreadable file.
        raise RuntimeError("Unable to fetch a readable image after multiple attempts; dataset may be corrupted.")

    def __len__(self):
        return len(self.dir1_list) + len(self.dir2_list)

In [58]:
train_cats_dir = './PetImages/Cat'
train_dogs_dir = './PetImages/Dog'

test_cats_dir = './PetImages/Cat'
test_dogs_dir = './PetImages/Dog'

# Guard against evaluating on training data: if any test folder is also a train
# folder, "test" metrics partly (or entirely) measure training accuracy.
_train_dirs = {os.path.abspath(p) for p in (train_cats_dir, train_dogs_dir)}
_test_dirs = {os.path.abspath(p) for p in (test_cats_dir, test_dogs_dir)}
if _train_dirs & _test_dirs:
    warnings.warn(
        "Train and test datasets share directories "
        f"{sorted(_train_dirs & _test_dirs)}; test metrics will not measure "
        "generalisation. Point test_*_dir at a held-out split."
    )

# Enable remove_bad and set a bad_dir so problematic files are moved and won't stall later
train_ds_catsdogs = Dataset2class(train_cats_dir, train_dogs_dir, validate_images=False, remove_bad=True, bad_dir='./bad_images')
test_ds_catsdogs = Dataset2class(test_cats_dir, test_dogs_dir, validate_images=False, remove_bad=True, bad_dir='./bad_images')

print("Train size:", len(train_ds_catsdogs), "Test size:", len(test_ds_catsdogs))

Train size: 24997 Test size: 24997


<h3>Data loaders</h3>

In [59]:
batch_size = 16

# num_workers=0 keeps loading in this process. Dataset2class removes unreadable
# files inside __getitem__, and with worker processes those removals would only
# change each worker's own copy of the dataset, not train_ds_catsdogs here.
# drop_last=True discards the smaller final batch each epoch.
train_loader = torch.utils.data.DataLoader(
    train_ds_catsdogs, shuffle=True,
    batch_size=batch_size, num_workers=0, drop_last=True
)
# Evaluation gains nothing from shuffling; a fixed order keeps per-batch output
# reproducible. drop_last=False so every test sample is scored.
test_loader = torch.utils.data.DataLoader(
    test_ds_catsdogs, shuffle=False,
    batch_size=batch_size, num_workers=0, drop_last=False
)

In [60]:
# Quick dataloader sanity check: iterate a few batches and ensure it doesn't stall; then print runtime-removed
import time
def check_loader(loader, max_batches=5):
    """Pull at most ``max_batches`` batches from ``loader`` and print their shapes.

    A quick sanity check that the loader yields well-formed batches without
    stalling. Each batch must be a mapping with ``'img'`` and ``'label'``
    tensors. Prints the elapsed wall-clock time and returns None.
    ``max_batches <= 0`` fetches no batches at all.
    """
    start = time.perf_counter()  # monotonic clock, appropriate for measuring intervals
    # zip() advances range() first, so once max_batches batches have been seen it
    # stops without requesting another batch from the loader.
    for i, sample in zip(range(max_batches), loader):
        print(f"Batch {i+1}: img {sample['img'].shape}, labels {sample['label'].shape}")
    print("Elapsed (s):", time.perf_counter() - start)

check_loader(train_loader, max_batches=5)

# Report files that __getitem__ dropped while serving those batches (empty lists mean none).
for caption, dataset in (("Runtime removed (train):", train_ds_catsdogs),
                         ("Runtime removed (test):", test_ds_catsdogs)):
    print(caption, dataset.runtime_removed)

Batch 1: img torch.Size([16, 3, 64, 64]), labels torch.Size([16])
Batch 2: img torch.Size([16, 3, 64, 64]), labels torch.Size([16])
Batch 3: img torch.Size([16, 3, 64, 64]), labels torch.Size([16])
Batch 4: img torch.Size([16, 3, 64, 64]), labels torch.Size([16])
Batch 5: img torch.Size([16, 3, 64, 64]), labels torch.Size([16])
Elapsed (s): 0.1394639015197754
Runtime removed (train): {'./PetImages/Cat': [], './PetImages/Dog': []}
Runtime removed (test): {'./PetImages/Cat': [], './PetImages/Dog': []}


<h3>Architecture</h3>

In [61]:
# Convolutional neural network
class ConvNet(nn.Module):
    """Small CNN classifier that returns raw logits (no softmax).

    Five unpadded 3x3 convolutions with 32 channels each, each followed by
    LeakyReLU(0.2), with 2x2 max-pooling after the first three. Then global
    average pooling and a 32 -> 10 -> ``num_classes`` MLP. The output has
    shape ``(batch, num_classes)``.

    Args:
        in_channels: number of input image channels (default 3, matching the
            RGB tensors built by ``Dataset2class``).
        num_classes: number of output logits (default 2).

    Spatial size: each unpadded conv trims 2 pixels and each pool halves, so
    inputs must be at least 54x54. A 64x64 input goes
    64 -> 62 -> 31 -> 29 -> 14 -> 12 -> 6 -> 4 -> 2 -> (global pool) 1.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2):
        super().__init__()

        # Shared, parameter-free modules reused after every layer.
        self.act = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(2, 2)

        self.conv0 = nn.Conv2d(in_channels, 32, 3, stride=1, padding=0)
        self.conv1 = nn.Conv2d(32, 32, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=1, padding=0)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=1, padding=0)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=1, padding=0)

        self.adaptivepool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        # Global pooling leaves one value per channel, so 32 flattened features.
        self.linear1 = nn.Linear(32, 10)
        self.linear2 = nn.Linear(10, num_classes)

    def forward(self, x):
        # Spatial sizes in comments are for a 64x64 input.
        out = self.conv0(x)       # 62
        out = self.act(out)
        out = self.maxpool(out)   # 31

        out = self.conv1(out)     # 29
        out = self.act(out)
        out = self.maxpool(out)   # 14

        out = self.conv2(out)     # 12
        out = self.act(out)
        out = self.maxpool(out)   # 6

        out = self.conv3(out)     # 4
        out = self.act(out)

        out = self.conv4(out)     # 2
        out = self.act(out)

        out = self.adaptivepool(out)  # 1x1
        out = self.flatten(out)       # (batch, 32)
        out = self.linear1(out)
        out = self.act(out)
        out = self.linear2(out)       # (batch, num_classes) logits

        return out

In [62]:
def count_parameters(model):
    """Total number of scalar elements across the model's trainable (``requires_grad``) parameters."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))

In [63]:
# Seed torch's global RNG right before building the model so the random weight
# init is reproducible. The DataLoaders were created without an explicit
# `generator`, so their shuffling also draws from this RNG from here on.
SEED = 42
torch.manual_seed(SEED)
model = ConvNet()

In [64]:
# Display the module tree (the cell's last expression is rendered as output).
model

ConvNet(
  (act): LeakyReLU(negative_slope=0.2)
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (adaptivepool): AdaptiveAvgPool2d(output_size=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=32, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=2, bias=True)
)

In [65]:
# Number of trainable parameters, shown as the cell's output.
count_parameters(model)

38240

In [66]:
# Smoke test: push one training batch through the untrained model and check the
# logits shape. No gradients are needed, so skip building the autograd graph.
sample = next(iter(train_loader))
img, label = sample['img'], sample['label']
with torch.no_grad():
    logits = model(img)
assert logits.shape == (img.shape[0], 2), f"unexpected logits shape {tuple(logits.shape)}"

In [67]:
# Shape of the smoke-test batch: (batch, channels, H, W).
img.shape

torch.Size([16, 3, 64, 64])

<h3>Loss and optimizer</h3>

In [68]:
LEARNING_RATE = 1e-4  # Adam step size

# Cross-entropy on raw logits against integer class labels; Adam with its betas spelled out.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

In [69]:
def accuracy(pred, label):
    """Fraction of rows whose argmax over ``pred`` matches the target class.

    ``pred`` holds logits with classes along dim 1. ``label`` is either a
    vector of class indices or, when it has more than one dimension, a one-hot
    / score matrix whose argmax along dim 1 is the target. Returns a Python float.
    """
    targets = label.detach()
    if targets.dim() > 1:
        targets = targets.argmax(dim=1)
    hits = pred.detach().argmax(dim=1).eq(targets)
    return hits.float().mean().item()

<h3>Training loop</h3>

In [70]:
epochs = 10

# Training mode for all layers. This ConvNet has no dropout/batchnorm, but the
# call keeps the loop correct if such layers are added later.
model.train()
for epoch in range(epochs):
    loss_val = 0.0        # sum of per-batch losses over completed batches
    acc_val = 0.0         # sum of per-batch accuracies over completed batches
    done_batches = 0      # batches that finished a full optimizer step
    skipped_batches = 0
    # Label the bar when it is created: a description set after the loop only
    # reaches a bar that has already finished rendering.
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for sample in pbar:
        # NOTE: errors raised while *loading* a batch surface from the `for`
        # statement above and are not caught here; this guard only covers the
        # forward/backward/step of an already-loaded batch.
        try:
            optimizer.zero_grad()

            img, label = sample['img'], sample['label']
            pred = model(img)

            loss = loss_fn(pred, label)  # CrossEntropyLoss expects class indices
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_acc = accuracy(pred, label)
        except Exception as e:
            # Log and skip the batch instead of aborting the whole run.
            skipped_batches += 1
            print(f"Skipping batch due to error: {e}")
            continue

        # Accumulate only after the whole step succeeded, so a failing batch
        # never contributes a partial result.
        loss_val += batch_loss
        acc_val += batch_acc
        done_batches += 1
        # refresh=False: let tqdm's own rate limiting decide when to redraw.
        pbar.set_postfix(loss=loss_val / done_batches, acc=acc_val / done_batches,
                         skipped=skipped_batches, refresh=False)

    # Average over batches that actually trained; skipped ones contributed nothing.
    mean_loss = loss_val / done_batches if done_batches else float('nan')
    mean_acc = acc_val / done_batches if done_batches else float('nan')
    print(mean_loss, mean_acc)

100%|██████████| 1562/1562 [01:06<00:00, 23.38it/s]


0.6830153596729384 0.5513364276568502


100%|██████████| 1562/1562 [01:05<00:00, 23.88it/s]


0.6401907743816949 0.6321222791293214


100%|██████████| 1562/1562 [01:06<00:00, 23.59it/s]


0.6167383771699766 0.662291933418694


100%|██████████| 1562/1562 [01:05<00:00, 23.87it/s]


0.582407337297398 0.6954225352112676


100%|██████████| 1562/1562 [01:05<00:00, 23.88it/s]


0.5552855204032416 0.7191101152368758


100%|██████████| 1562/1562 [01:05<00:00, 23.86it/s]


0.5357778915865931 0.7324743918053778


100%|██████████| 1562/1562 [01:10<00:00, 22.27it/s]


0.5192595571279526 0.7464788732394366


100%|██████████| 1562/1562 [03:05<00:00,  8.42it/s]


0.5039716493362196 0.7554417413572343


100%|██████████| 1562/1562 [01:13<00:00, 21.27it/s]


0.49149217350687474 0.7623239436619719


100%|██████████| 1562/1562 [01:10<00:00, 22.14it/s]

0.478084067700767 0.7732874519846351



