In [70]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.models import densenet121, DenseNet121_Weights
import matplotlib.pyplot as plt
from tqdm import tqdm

In [71]:
# Dataset
# --------------------
class COVIDxDataset(Dataset):
    """Map-style dataset built from a COVIDx CT split file.

    Every line of the split file is parsed as six space-separated fields:
    ``filename class xmin ymin xmax ymax``.
    """

    def __init__(self, txt_file, img_dir, transform=None, use_crop=False):
        """
        Args:
            txt_file: path to the train/val/test ``.txt`` split file.
            img_dir: directory that contains the image files.
            transform: optional callable applied to the PIL image before it
                is returned.
            use_crop: when True, each image is cropped to its
                ``(xmin, ymin, xmax, ymax)`` box before the transform runs.
        """
        self.df = pd.read_csv(
            txt_file, sep=" ", header=None,
            names=["filename", "class", "xmin", "ymin", "xmax", "ymax"]
        )
        self.img_dir = img_dir
        self.transform = transform
        self.use_crop = use_crop

    def __len__(self):
        """Return the number of rows read from the split file."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB, optionally cropped, then passed
        through ``self.transform`` when one was supplied. ``label`` is the
        ``class`` column as a plain ``int``.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row["filename"])

        # Close the underlying file deterministically instead of relying on
        # Pillow's lazy close; convert() returns a fully loaded image that no
        # longer depends on the open file.
        with Image.open(img_path) as src:
            image = src.convert("RGB")
        label = int(row["class"])

        if self.use_crop:
            # Cast to built-in ints so the crop box never carries pandas/NumPy
            # scalar types.
            box = tuple(int(row[key]) for key in ("xmin", "ymin", "xmax", "ymax"))
            image = image.crop(box)

        # Compare against None: a valid callable may still be falsy
        # (e.g. a container-like module with length 0).
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [72]:
# Preprocessing transforms
# --------------------
# Per-channel statistics handed to Normalize (one value per RGB channel).
# NOTE(review): these look like the usual ImageNet statistics, which would
# match the pretrained backbone — confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to a fixed 224x224 input, convert to a tensor, then normalise.
_preprocess_steps = [
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = T.Compose(_preprocess_steps)


In [73]:
# Dataset paths
# --------------------
img_dir = "dataset/3A_images"  # folder with the .png files
train_txt = "dataset/train_COVIDx_CT-3A.txt"
val_txt = "dataset/val_COVIDx_CT-3A.txt"
test_txt = "dataset/test_COVIDx_CT-3A.txt"

# All three splits share the same image folder and preprocessing pipeline.
train_dataset, val_dataset, test_dataset = (
    COVIDxDataset(split_txt, img_dir, transform=transform)
    for split_txt in (train_txt, val_txt, test_txt)
)

# Only the training loader shuffles; validation and test keep file order.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=reshuffle)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [74]:
# Model
# --------------------
def get_model(num_classes=3, pretrained=True):
    """Build a DenseNet-121 whose classifier head emits ``num_classes`` outputs.

    Args:
        num_classes: output size of the replacement ``nn.Linear`` classifier.
        pretrained: when truthy the backbone is created with
            ``DenseNet121_Weights.DEFAULT``; otherwise with ``weights=None``.

    Returns:
        The DenseNet-121 model with its ``classifier`` layer replaced.
    """
    backbone = densenet121(
        weights=DenseNet121_Weights.DEFAULT if pretrained else None
    )
    # Swap the stock head for one sized to this task, keeping its input width.
    backbone.classifier = nn.Linear(backbone.classifier.in_features, num_classes)
    return backbone

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Используется устройство: {device}")

# Three-class DenseNet-121 with pretrained weights, moved to the chosen device.
model = get_model(num_classes=3, pretrained=True)
model = model.to(device)

# Cross-entropy loss and Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Используется устройство: cuda


In [75]:
# Функции обучения и валидации
# --------------------
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module to train; switched to train mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` accumulated over every sample seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc="Training", leave=False)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the final division is a
        # per-sample average (assumes the criterion returns a batch mean).
        batch_loss = loss.item()
        loss_sum += batch_loss * images.size(0)

        predicted = logits.max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        progress.set_postfix(loss=batch_loss, acc=(n_correct / n_seen))

    return loss_sum / n_seen, n_correct / n_seen

In [76]:
def evaluate(model, dataloader, criterion, device):
    """Measure loss and accuracy over ``dataloader`` without updating weights.

    Args:
        model: module to evaluate; switched to eval mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` accumulated over every sample seen.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc="Validation", leave=False)
    # Gradients are not needed for scoring.
    with torch.no_grad():
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels).item()

            # Weight by batch size so the final division is a per-sample
            # average (assumes the criterion returns a batch mean).
            loss_sum += batch_loss * images.size(0)

            predicted = logits.max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            progress.set_postfix(loss=batch_loss, acc=(n_correct / n_seen))

    return loss_sum / n_seen, n_correct / n_seen

In [77]:
# Training
# --------------------
num_epochs = 4
best_model_path = "best_densenet_ct.pth"
best_val_acc = 0.0

# Per-epoch history, one list per (split, metric); the plotting cell reads these.
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
test_losses, test_accuracies = [], []

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Record this epoch's metrics in their history lists.
    for history, value in (
        (train_losses, train_loss),
        (val_losses, val_loss),
        (test_losses, test_loss),
        (train_accuracies, train_acc),
        (val_accuracies, val_acc),
        (test_accuracies, test_acc),
    ):
        history.append(value)

    # Checkpoint only when validation accuracy strictly improves; the test
    # metrics are logged but never used for model selection.
    improved = val_acc > best_val_acc
    if improved:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"🔹 Сохранена новая лучшая модель (val_acc={val_acc:.4f})")

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train: loss={train_loss:.4f}, acc={train_acc:.4f} | "
          f"Val: loss={val_loss:.4f}, acc={val_acc:.4f} | "
          f"Test: loss={test_loss:.4f}, acc={test_acc:.4f}")

                                                                                            

🔹 Сохранена новая лучшая модель (val_acc=0.8644)
Epoch [1/4] Train: loss=0.0193, acc=0.9939 | Val: loss=0.4605, acc=0.8644 | Test: loss=0.4370, acc=0.8813


                                                                                            

KeyboardInterrupt: 

In [None]:
# Load the best model
# --------------------
# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# basic containers: with weights_only=True a tampered or stale .pth file
# cannot execute arbitrary code on load.
best_state = torch.load(best_model_path, map_location=device, weights_only=True)
model.load_state_dict(best_state)
print(f"✅ Загружена лучшая модель с val_acc={best_val_acc:.4f}")

In [None]:
# Visualisation
# --------------------
# One figure, two side-by-side panels: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Loss
for series, legend_label in (
    (train_losses, "Train Loss"),
    (val_losses, "Val Loss"),
    (test_losses, "Test Loss"),
):
    loss_ax.plot(series, label=legend_label)
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Loss curves")
loss_ax.legend()

# Accuracy
for series, legend_label in (
    (train_accuracies, "Train Acc"),
    (val_accuracies, "Val Acc"),
    (test_accuracies, "Test Acc"),
):
    acc_ax.plot(series, label=legend_label)
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.set_title("Accuracy curves")
acc_ax.legend()

fig.tight_layout()
plt.show()