## 1) 기본 세팅

In [1]:
import os
import torch, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader,WeightedRandomSampler
from torchvision import datasets, transforms
from torchvision.models import densenet121, DenseNet121_Weights
import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

IMG_SIZE = 512
DATA_ROOT = 'chest_xray'
BATCH_TRAIN = 16

w = DenseNet121_Weights.IMAGENET1K_V1


## 2) 데이터 불러오기 및 전처리 

In [3]:
pre = w.transforms()          
train_tf = transforms.Compose([
transforms.Grayscale(num_output_channels=3),
transforms.Resize((IMG_SIZE,IMG_SIZE)),
transforms.RandomHorizontalFlip(p = 0.5),
transforms.ColorJitter(brightness=0.1,contrast=0.1),
transforms.ToTensor(),
transforms.Normalize(pre.mean,pre.std)
])

val_tf = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMG_SIZE,IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(pre.mean,pre.std)
])

train_ds = datasets.ImageFolder("chest_xray/train", transform=train_tf)
val_ds   = datasets.ImageFolder("chest_xray/val",   transform=val_tf)
test_ds   = datasets.ImageFolder("chest_xray/test",   transform=val_tf)


## 3) 클래스 불균형 처리

In [4]:
import numpy as np
labels = np.array(train_ds.targets)
class_counts = np.bincount(labels)
class_weights = 1.0/ (class_counts + 1e-6)
sample_weights = class_weights[labels]
len(sample_weights)
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

## 4) 배치용 데이터 분할

In [5]:
train_dl = DataLoader(train_ds, batch_size=BATCH_TRAIN, sampler = sampler, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=BATCH_TRAIN, shuffle=False, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=BATCH_TRAIN, shuffle=False, num_workers=4)

In [118]:
model = densenet121(w)
model.classifier = nn.Linear(model.classifier.in_features, 2)
model.to(device)

epochs = 20
lr = 1e-3
l2 = 1e-5
criterion = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(model.parameters(),lr = lr, weight_decay=l2)



## 5) 모델 학습

In [119]:
step = 0
writer = SummaryWriter()

model.train()
for epoch in range(epochs):
    for data, label in tqdm.tqdm(train_dl):
        optim.zero_grad()
        data, label = data.to(device), label.to(device)
        pred = model(data)
        loss = criterion(pred,label)
        loss.backward()

        writer.add_scalar("Loss/train", loss.item(), step)
        step += 1
        optim.step()
        
    print(f"{epoch + 1} loss : {loss.item()}")

from datetime import datetime
now = datetime.now() 
timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
torch.save(model.state_dict(), f"model_{timestamp}.pth")

100%|██████████| 326/326 [00:43<00:00,  7.50it/s]


1 loss : 0.06262947618961334


100%|██████████| 326/326 [00:43<00:00,  7.47it/s]


2 loss : 0.028551165014505386


100%|██████████| 326/326 [00:43<00:00,  7.44it/s]


3 loss : 0.4476194381713867


100%|██████████| 326/326 [00:43<00:00,  7.45it/s]


4 loss : 0.012892303988337517


100%|██████████| 326/326 [00:43<00:00,  7.45it/s]


5 loss : 0.01439596340060234


100%|██████████| 326/326 [00:43<00:00,  7.48it/s]


6 loss : 0.01873016357421875


100%|██████████| 326/326 [00:43<00:00,  7.47it/s]


7 loss : 0.028534971177577972


100%|██████████| 326/326 [00:43<00:00,  7.42it/s]


8 loss : 0.005746419541537762


100%|██████████| 326/326 [00:43<00:00,  7.43it/s]


9 loss : 0.6241193413734436


100%|██████████| 326/326 [00:43<00:00,  7.45it/s]


10 loss : 0.0026556539814919233


100%|██████████| 326/326 [00:43<00:00,  7.46it/s]


11 loss : 0.005396376829594374


100%|██████████| 326/326 [00:43<00:00,  7.44it/s]


12 loss : 0.05081372335553169


100%|██████████| 326/326 [00:43<00:00,  7.47it/s]


13 loss : 0.035489197820425034


100%|██████████| 326/326 [00:43<00:00,  7.43it/s]


14 loss : 0.07934334129095078


100%|██████████| 326/326 [00:43<00:00,  7.49it/s]


15 loss : 0.012704821303486824


100%|██████████| 326/326 [00:43<00:00,  7.49it/s]


16 loss : 0.00034952996065840125


100%|██████████| 326/326 [00:43<00:00,  7.45it/s]


17 loss : 0.00923384539783001


100%|██████████| 326/326 [00:43<00:00,  7.46it/s]


18 loss : 0.012179488316178322


100%|██████████| 326/326 [00:43<00:00,  7.46it/s]


19 loss : 0.0008390186703763902


100%|██████████| 326/326 [00:43<00:00,  7.42it/s]

20 loss : 0.08627404272556305





## 6) 모델 평가

In [6]:
from sklearn.metrics import f1_score 
model = densenet121(w) 
model.classifier = torch.nn.Linear(model.classifier.in_features, 2) 
model.load_state_dict(torch.load("model_f1_9123.pth", map_location=device))
model.to(device)

all_labels = []
all_preds = []
model.eval()
with torch.no_grad(): 
    for data, label in test_dl: 
        data, label = data.to(device), label.to(device) 
        outputs = model(data) 
        _, predicted = torch.max(outputs, 1) 
        all_labels.extend(label.cpu().numpy()) 
        all_preds.extend(predicted.cpu().numpy()) 

f1 = f1_score(all_labels, all_preds, average='binary') 
print(f"F1 Score: {f1}")

F1 Score: 0.9123222748815166
