## 프레임워크 및 모듈 설치

In [None]:
# ! pip install simpleITK
# ! pip install tqdm
# ! pip install torch torchvision pandas torchmetrics

## 모듈참조

In [None]:
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet50
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
import SimpleITK as sitk
import pandas as pd
from tqdm import tqdm
import os
import numpy as np

#### 사용가능한 리소스 확인

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda')

#### 파라미터 세팅
학습성능 효율성 및 딥러닝의 불확식성을 위한 시드값 정의 x

In [None]:
batch_size = 16
root = "./"

## 데이터 로드 및 분석

In [None]:
trainframe = pd.read_csv(f'{root}train.csv')
testframe = pd.read_csv(f'{root}test.csv')
submissionframe = pd.read_csv(f'{root}sample_submission.csv')
print(len(trainframe.columns))
print(len(submissionframe.columns))
print((trainframe.columns[2:]==submissionframe.columns[1:]).all()) # 제출파일에 이상이 있는지 체크
classnum = len(trainframe.loc[0].keys())-2 #클래스 개수

3469
3468
True


## 평가 메트릭 구현

In [6]:
class MeanCellCorrelation(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, pred, target):
        batch_size=pred.size()[0]
        targetvar= target-target.mean()
        predvar = pred-pred.mean()
        return (((targetvar*predvar)/ ((targetvar*targetvar).sqrt()* (predvar*predvar).sqrt()+1e-7)).mean(dim=1)).mean()

class MeanGenesCorrelation(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, pred, target):
        batch_size=pred.size()[0]
        targetvar= target-target.mean()
        predvar = pred-pred.mean()

        return (((targetvar*predvar)/ ((targetvar*targetvar).sqrt()* (predvar*predvar).sqrt()+1e-7)).mean(dim=1)).max()
    
class CustomLoss(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metricf = nn.SmoothL1Loss(beta=2)
        self.metricf2 = nn.HuberLoss(delta=0.5)
    def forward(self, pred, target):
        metric = self.metricf(pred, target)
        metric = self.metricf2(pred, target)
        return metric 

## 유틸 클래스 및 함수정의
#### 딥러닝 데이터셋

In [8]:
class Dataset(object):
  def __init__(self, frame, datadir, trainmode=True) -> None:
    super().__init__()
    self.frame = frame
    self.datadir = datadir
    self.trainmode = trainmode
  def __getitem__(self, index):
    row = self.frame.loc[index]
    path = os.path.join(self.datadir, row['ID']+".png")
    img = sitk.ReadImage(path)
    img = sitk.GetArrayFromImage(img)
    label = row.iloc[2:].apply(lambda x: float(x))
    ret = {
      "img": torchvision.transforms.ToTensor()(img).type(torch.float32),
      "label": torch.Tensor(label)
    }
    return ret

  def __len__(self):
    return len(self.frame)

#### 데이터 분할 및 로더

In [9]:
dataset = Dataset(trainframe, f"{root}train")
trainset, validset = random_split(dataset, [0.9, 0.1])
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
validloader = DataLoader(validset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)
testloader = DataLoader(Dataset(testframe, f"{root}test", False), batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)

## 모델 정의
#### 어텐션 모델 적용 및 응용

In [10]:
class CustomResNet(torch.nn.Module):
  def __init__(self, outchannel) -> None:
    super().__init__()
    self.net = resnet50()
    self.attention = CBAM(outchannel)
    self.net.fc = torch.nn.Sequential(
        torch.nn.Dropout(p=0.1, inplace=True),
        torch.nn.Linear(in_features=2048, out_features=outchannel, bias=False))
  def forward(self, x):
    x = self.net(x)
    x = self.attention(x.unsqueeze(-1).unsqueeze(-1))
    return x
  
class CBAM(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CBAM, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Channel Attention
        self.channel_fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )
        
        # Spatial Attention
        self.spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        
    def forward(self, x):
        b, c, w, h = x.size()
        
        # Channel Attention
        avg_out = self.channel_fc(self.avg_pool(x).view(b, c))
        max_out = self.channel_fc(self.max_pool(x).view(b, c))
        channel_att = torch.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        
        x = x * channel_att
        
        # Spatial Attention
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_att = torch.sigmoid(self.spatial_conv(torch.cat([avg_out, max_out], dim=1)))
        
        return (x * spatial_att).squeeze()

#### 학습 및 평가 함수 정의

In [None]:
def valid(net, validloader, device):
    maef = MeanAbsoluteError().to(device)
    msef = MeanSquaredError().to(device)
    mccf = MeanCellCorrelation().to(device)
    mgcf = MeanGenesCorrelation().to(device)
    net.eval()
    mae = 0
    mse = 0
    mcc = 0
    mgc = 0
    length = len(validloader)
    with torch.no_grad():
        for batch in tqdm(validloader):
          
          img= [b["img"].to(device) for b in batch]
          label = [b["label"].to(device) for b in batch]
          img = torch.stack(img)
          label = torch.stack(label)
          out = net(img)
          mae += maef(out, label).item()
          mse += msef(out, label).item()
          mcc += mccf(out, label).item()
          mgc += mgcf(out, label).item()

    return {"mse": mse/length, "mae": mae/length, "mcc": mcc/length, "mgc": mgc/length}

def train(net, loader, validloader, epoch, optimizer, loss_fn, device, save_dir ):
  length = len(loader)
  history = {"mcc":[], "mgc":[], "loss":[],"mse":[]}
  for e in range(epoch):
      losses=0
      for batch in tqdm(loader):
          net.train()
          img= [b["img"] for b in batch]
          label = [b["label"].to(device) for b in batch]
          img = torch.stack(img).to(device)
          label = torch.stack(label).to(device)
          output = net(img)
          loss = loss_fn(output, label)
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()
          losses+=loss.item()
  
      hist = valid(net, validloader, device)
      mae = hist["mae"]
      mse = hist["mse"]
      mcc = hist["mcc"]
      mgc = hist["mgc"]
      print(f'epoch: {e+1}, loss: {losses/length} mse: {mse} mae: {mae} mcc: {mcc} mgc: {mgc}')
      if e == 0:
        torch.save(net.state_dict(), os.path.join(save_dir, "net.pt"))
        print("model is updated")
      else:
        eval = np.array(history["mse"])
        if eval.min() >= mse:
           torch.save(net.state_dict(), os.path.join(save_dir, "net.pt"))
           print("model is updated")
      history["mcc"].append(mcc); history["mgc"].append(mgc); history["loss"].append(losses/length); history["mse"].append(mse)
  return history

#### 실제학습을위한 모델 준비

In [16]:
net = CustomResNet(classnum)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
loss_fn = CustomLoss().to(DEVICE)
net.float()
net.to(DEVICE)
print("model is ready")

model is ready


## 학습

In [13]:
train(net, trainloader, validloader, 80, optimizer, loss_fn, DEVICE, f"{root}Models")

  "label": torch.Tensor(label)
100%|██████████| 394/394 [00:59<00:00,  6.64it/s]
100%|██████████| 44/44 [00:04<00:00,  9.03it/s]


epoch: 1, loss: 0.020740919685000695 mse: 0.04757165053690022 mae: 0.08682807788930157 mcc: 0.754357944835316 mgc: 0.7853068912571127
model is updated


100%|██████████| 394/394 [00:58<00:00,  6.69it/s]
100%|██████████| 44/44 [00:04<00:00,  9.45it/s]


epoch: 2, loss: 0.019463650333851123 mse: 0.046899140304462475 mae: 0.0863194213333455 mcc: 0.7558432370424271 mgc: 0.8005370633168654
model is updated


100%|██████████| 394/394 [01:01<00:00,  6.40it/s]
100%|██████████| 44/44 [00:04<00:00,  9.84it/s]


epoch: 3, loss: 0.0193318032522492 mse: 0.046493556519800965 mae: 0.08974541283466599 mcc: 0.7383367229591716 mgc: 0.7946212115612897
model is updated


100%|██████████| 394/394 [01:00<00:00,  6.52it/s]
100%|██████████| 44/44 [00:04<00:00,  9.79it/s]


epoch: 4, loss: 0.019258778140440507 mse: 0.04688257935710929 mae: 0.08792756667191332 mcc: 0.751555017449639 mgc: 0.7976691682230342


100%|██████████| 394/394 [01:01<00:00,  6.45it/s]
100%|██████████| 44/44 [00:04<00:00,  9.75it/s]


epoch: 5, loss: 0.019197680642308316 mse: 0.04668862745165825 mae: 0.08950595947151835 mcc: 0.7448597658764232 mgc: 0.7981776838952844


100%|██████████| 394/394 [01:00<00:00,  6.54it/s]
100%|██████████| 44/44 [00:04<00:00,  9.82it/s]


epoch: 6, loss: 0.019167903709097715 mse: 0.04734226477078416 mae: 0.08521600270813162 mcc: 0.7663702084259554 mgc: 0.8219195902347565


100%|██████████| 394/394 [01:01<00:00,  6.45it/s]
100%|██████████| 44/44 [00:04<00:00,  9.51it/s]


epoch: 7, loss: 0.019133534469183176 mse: 0.04646640321747823 mae: 0.0907300626012412 mcc: 0.7378598045219075 mgc: 0.8035046566616405
model is updated


100%|██████████| 394/394 [01:00<00:00,  6.48it/s]
100%|██████████| 44/44 [00:04<00:00,  9.88it/s]


epoch: 8, loss: 0.019108942660224164 mse: 0.0467190837318247 mae: 0.08720715953545137 mcc: 0.75529645383358 mgc: 0.8093236915089868


100%|██████████| 394/394 [01:00<00:00,  6.48it/s]
100%|██████████| 44/44 [00:04<00:00,  9.82it/s]


epoch: 9, loss: 0.019084869259921124 mse: 0.0463532999327237 mae: 0.08753386986526576 mcc: 0.750680301677097 mgc: 0.8194150978868658
model is updated


100%|██████████| 394/394 [01:00<00:00,  6.48it/s]
100%|██████████| 44/44 [00:04<00:00,  9.84it/s]


epoch: 10, loss: 0.019038026320374556 mse: 0.0471722778271545 mae: 0.09035758623345332 mcc: 0.7377400303428824 mgc: 0.821353405714035


100%|██████████| 394/394 [00:58<00:00,  6.74it/s]
100%|██████████| 44/44 [00:04<00:00,  9.83it/s]


epoch: 11, loss: 0.01900999821476676 mse: 0.0471391019157388 mae: 0.08879066207869486 mcc: 0.7476014684547078 mgc: 0.8167427737604488


100%|██████████| 394/394 [00:58<00:00,  6.73it/s]
100%|██████████| 44/44 [00:04<00:00,  9.83it/s]


epoch: 12, loss: 0.018966744488417196 mse: 0.047036417590623554 mae: 0.08735913597047329 mcc: 0.7553093067624352 mgc: 0.8221081549471075


 54%|█████▍    | 214/394 [00:31<00:26,  6.74it/s]


KeyboardInterrupt: 

## 제출및 결과

In [14]:
def submission(net, testloader, device, keys, subframePath = f'{root}sample_submission.csv'):
  net.eval()
  submit = pd.read_csv(subframePath)
  submit.iloc[:, 1:] = inference(net, testloader, device).astype(np.float32)
  submit.to_csv('./result.csv', index=False)
  return submit
def inference(model, test_loader, device):
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            img= [b["img"].to(device) for b in batch]
            img = torch.stack(img)
            pred = model(img)
            preds.append(pred.detach().cpu())
    
    preds = torch.cat(preds).numpy()

    return preds

#### 체크포인트 모델 로드 및 결과

In [15]:
net.load_state_dict(torch.load(f"{root}/Models/net.pt"))
submission(net, testloader, DEVICE, submissionframe.columns[1:])

100%|██████████| 143/143 [00:07<00:00, 19.51it/s]


Unnamed: 0,ID,AL645608.7,HES4,TNFRSF18,TNFRSF4,SDF4,ACAP3,INTS11,MXRA8,AL391244.2,...,MT-ATP8,MT-ATP6,MT-CO3,MT-ND3,MT-ND4L,MT-ND4,MT-ND5,MT-CYB,BX004987.1,AL592183.1
0,TEST_0000,-0.000328,0.033912,0.000413,-0.000858,0.155625,0.026579,0.091520,0.038175,0.001255,...,0.130246,3.652415,4.063092,3.036614,0.712508,3.625293,1.190471,3.521472,-0.000270,0.000689
1,TEST_0001,-0.000148,0.015827,-0.000187,0.003037,0.660968,0.069624,0.181265,0.062043,-0.000183,...,0.229034,3.200570,3.393326,2.687198,0.880518,3.128765,1.241677,2.905874,0.000151,-0.000033
2,TEST_0002,-0.000049,0.017929,-0.000177,0.003634,0.637279,0.067224,0.168017,0.060014,-0.000046,...,0.242973,3.219442,3.390536,2.712713,0.861027,3.144180,1.225822,2.919336,0.000111,0.000316
3,TEST_0003,-0.000289,0.024179,0.000924,0.002239,0.556892,0.059299,0.156795,0.065863,0.000461,...,0.131123,3.050404,3.263935,2.499124,0.557684,2.961166,0.941910,2.749718,-0.000274,0.000730
4,TEST_0004,-0.000608,0.027104,-0.000253,0.002477,0.649496,0.077538,0.180734,0.066633,-0.000031,...,0.319531,3.616661,4.009782,3.272632,1.168183,3.524832,1.587888,3.427722,-0.000052,0.001007
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2272,TEST_2272,-0.000099,0.016589,0.000280,0.001853,0.618858,0.072265,0.161873,0.048846,0.000288,...,0.249873,3.253192,3.452406,2.798342,0.902863,3.205551,1.267059,3.011997,0.000019,0.000473
2273,TEST_2273,-0.000092,0.013077,0.000554,0.001202,0.613543,0.066129,0.143283,0.043573,0.000286,...,0.408829,3.487122,3.606279,3.007125,1.136255,3.450487,1.477257,3.186907,-0.000107,0.000389
2274,TEST_2274,-0.000167,0.018747,0.000422,0.002381,0.598075,0.059240,0.163373,0.063571,0.000130,...,0.176617,3.102651,3.292513,2.550854,0.685026,3.033775,1.067981,2.797505,-0.000096,0.000508
2275,TEST_2275,-0.000004,0.020972,-0.000827,0.005164,0.655012,0.069091,0.179805,0.063583,-0.000408,...,0.225716,3.196722,3.383302,2.687022,0.816230,3.090373,1.205319,2.893295,0.000214,-0.000717
