In [7]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt

def show_images(images, titles=None, cmap=None):
    """ Показывает серию изображений. """
    n = len(images)
    fig, axs = plt.subplots(1, n, figsize=(15, 5))
    if n == 1:
        axs = [axs]
    for i, image in enumerate(images):
        axs[i].imshow(image.squeeze(), cmap=cmap)
        axs[i].axis('off')
        if titles is not None:
            axs[i].set_title(titles[i])
    plt.show()
    
class ImageToImageDataset(Dataset):
    def __init__(self, input_dir, target_dir=None, transform=None, target_transform=None, mode='train'):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform
        self.target_transform = target_transform or transform
        self.mode = mode
        self.filenames = [f for f in os.listdir(input_dir) if f.endswith('.png')]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        input_path = os.path.join(self.input_dir, self.filenames[idx])
        input_image = Image.open(input_path).convert('L')

        if self.transform:
            input_image = self.transform(input_image)

        if self.mode == 'train':
            target_path = os.path.join(self.target_dir, self.filenames[idx])
            target_image = Image.open(target_path).convert('L')

            if self.target_transform:
                target_image = self.target_transform(target_image)

            return input_image, target_image
        else:
            return input_image

import torchvision

def save_image(tensor, filename):
    """ Сохраняет тензор как изображение. """
    torchvision.utils.save_image(tensor, filename)
    
# Определение модели
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 1, 3, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.sigmoid(self.conv3(x))
        return x

# Параметры
num_epochs = 6
batch_size = 4
learning_rate = 0.001

# Подготовка данных
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Нормализация если нужна
])

target_transform = transforms.Compose([
    transforms.ToTensor()  # Только преобразование в тензор для масок
])

train_dataset = ImageToImageDataset(input_dir="C:/Users/rosti/Desktop/data/train_images", target_dir="C:/Users/rosti/Desktop/data/train_lung_masks", transform=transform, target_transform =target_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

model = SimpleCNN()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Обучение модели
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Визуализация
        

        # Печать статистики обучения
        if (i+1) % 100 == 0:
        #     num_images_to_show = 1
        #     for j in range(num_images_to_show):
        #         save_image(inputs[j], f'output/epoch_{epoch}_{i+1}_input_{j}.png')
        #         save_image(outputs[j], f'output/epoch_{epoch}_{i+1}_output_{j}.png')
        #         save_image(labels[j], f'output/epoch_{epoch}_{i+1}_label_{j}.png')

            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
model_path = "C:/Users/rosti/Desktop/model.pth"  # Укажите желаемый путь для сохранения модели
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")


Epoch [1/6], Step [100/6750], Loss: 0.4247
Epoch [1/6], Step [200/6750], Loss: 0.3752
Epoch [1/6], Step [300/6750], Loss: 0.4444
Epoch [1/6], Step [400/6750], Loss: 0.4063
Epoch [1/6], Step [500/6750], Loss: 0.4648
Epoch [1/6], Step [600/6750], Loss: 0.3335
Epoch [1/6], Step [700/6750], Loss: 0.4067
Epoch [1/6], Step [800/6750], Loss: 0.4246
Epoch [1/6], Step [900/6750], Loss: 0.4453
Epoch [1/6], Step [1000/6750], Loss: 0.4185
Epoch [1/6], Step [1100/6750], Loss: 0.4509
Epoch [1/6], Step [1200/6750], Loss: 0.4479
Epoch [1/6], Step [1300/6750], Loss: 0.3825
Epoch [1/6], Step [1400/6750], Loss: 0.3538
Epoch [1/6], Step [1500/6750], Loss: 0.4054
Epoch [1/6], Step [1600/6750], Loss: 0.4167
Epoch [1/6], Step [1700/6750], Loss: 0.4212
Epoch [1/6], Step [1800/6750], Loss: 0.4278
Epoch [1/6], Step [1900/6750], Loss: 0.3742
Epoch [1/6], Step [2000/6750], Loss: 0.4307
Epoch [1/6], Step [2100/6750], Loss: 0.3613
Epoch [1/6], Step [2200/6750], Loss: 0.3383
Epoch [1/6], Step [2300/6750], Loss: 0.38