In [1]:
import os

import pandas as pd
import torch
from PIL import Image
from ignite.engine import Events
from ignite.engine import create_supervised_evaluator
from ignite.engine import create_supervised_trainer
from ignite.metrics import Accuracy
from ignite.metrics import Loss
from torch import nn
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.transforms.v2 import Compose, PILToTensor, ToDtype

from datasets.LungDataset import LungDataset

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_EPOCHS = 5
BUTCH_SIZE = 20
IMAGE_DIR = '../data/train_images/'
MASK_DIR = '../data/train_lung_masks/'
TEST_IMAGE_DIR = '../data/test_images/'
TEST_MASK_DIR = '../data/test_lung_masks/'
LEARNING_RATE = 1e-3
USE_CHECKPOINT = False
GENERATE_ANSWER = False
TRAIN_MODEL = not GENERATE_ANSWER

In [None]:
transform = Compose([
    PILToTensor(),
    ToDtype(torch.float32, scale=True),
])

In [None]:
train_dataset = LungDataset(image_dir=IMAGE_DIR, mask_dir=MASK_DIR, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BUTCH_SIZE)

In [None]:
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(5, 5)
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 128, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(5, 5)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 512, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(3, 3)
        )

        self.flatten = nn.Flatten(-3)
        self.linear1 = nn.Linear(512, 100, dtype=torch.bfloat16)
        self.batchnorm1 = nn.BatchNorm1d(100)
        self.linear2 = nn.Linear(100, 20, dtype=torch.bfloat16)
        self.batchnorm2 = nn.BatchNorm1d(20)
        self.linear3 = nn.Linear(20, 3, dtype=torch.bfloat16)
        self.batchnorm3 = nn.BatchNorm1d(3)

        self.dropout = nn.Dropout(0.4)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

In [None]:
model = NN().to(DEVICE)
summary(model, train_dataset[0][0].shape)

In [None]:
loss_function = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
trainer = create_supervised_trainer(model, optimizer, loss_function)

In [None]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state,filename)

def load_checkpoint(checkpoint, model):
    model.load_state_dict(checkpoint['state_dict'])
    
if USE_CHECKPOINT:
    load_checkpoint(torch.load('checkpoint-5.pth.tar', map_location=DEVICE), model)

In [1]:
metrics_dict = {
    'accuracy': Accuracy(),
    'loss': Loss(loss_function)
}

train_evaluator = create_supervised_evaluator(model, metrics_dict)

def compute_epoch_results():
    train_evaluator.run(train_loader)
    
def log_iter_loss(engine):
    print(f'Epoch[{engine.state.epoch}] - Iter[{engine.state.iteration}]: loss = {engine.state.output}')
    
def create_checkpoint():
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    save_checkpoint(checkpoint)

def generate_test_answer():
    model.eval()

    answer = pd.DataFrame(columns=['id', 'target_feature'])
    test_images = os.listdir(TEST_IMAGE_DIR)
    
    for image_name in test_images:
        img_path = os.path.join(TEST_IMAGE_DIR, image_name)
        image = Image.open(img_path).convert('L')
        image = transform(image)
        image = image.to(device=DEVICE)
        
        with torch.no_grad():
            prediction = model(image.unsqueeze(1))
            answer.loc[answer.shape[0]] = [answer.shape[0], int(torch.argmax(prediction))]
    answer.to_csv("answer.csv", index=False)
    
    model.train()

SyntaxError: incomplete input (2959433018.py, line 27)

In [None]:
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), log_iter_loss)
trainer.add_event_handler(Events.EPOCH_COMPLETED, compute_epoch_results)
trainer.add_event_handler(Events.EPOCH_COMPLETED, create_checkpoint)
trainer.add_event_handler(Events.COMPLETED, generate_test_answer)

In [None]:
if GENERATE_ANSWER:
    generate_test_answer()

In [None]:
if TRAIN_MODEL:
    trainer.run(train_loader, NUM_EPOCHS)