## 프레임워크 및 모듈 설치

In [2]:
# ! pip install simpleITK
# ! pip install tqdm
# ! pip install torch torchvision pandas torchmetrics

## 모듈참조

In [3]:
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet50, resnext50_32x4d, mobilenet_v3_small
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
import SimpleITK as sitk
import pandas as pd
from tqdm import tqdm
import os
import numpy as np
import torch.nn.utils.prune as prune

#### 사용가능한 리소스 확인

In [4]:
# Select the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE  # bare expression so the notebook cell echoes the chosen device

device(type='cuda')

#### 파라미터 세팅
학습 성능의 효율성 및 딥러닝의 불확실성을 위한 시드값 정의 x

In [5]:
# Number of samples per mini-batch; passed to the train, validation and test
# DataLoaders.
batch_size = 64
# Base location of the CSV files, the image folders and the checkpoint folder.
# It is spliced directly into f-strings (e.g. f'{root}train.csv'), so it should
# end with a path separator.
root = "./"

## 데이터 로드 및 분석

In [6]:
# Load the training table, the test table and the submission template.
trainframe = pd.read_csv(f'{root}train.csv')
testframe = pd.read_csv(f'{root}test.csv')
submissionframe = pd.read_csv(f'{root}sample_submission.csv')
print(len(trainframe.columns))
print(len(submissionframe.columns))
# Sanity check of the submission template: the training columns after the first
# two must be identical to the submission columns after the first one.
print((trainframe.columns[2:]==submissionframe.columns[1:]).all())
# Number of regression targets: every training column except the first two.
# Taken from the header so it does not require a row labelled 0 to exist.
classnum = len(trainframe.columns) - 2

3469
3468
True


## 평가 메트릭 구현

In [None]:
class MeanCellCorrelation(torch.nn.Module):
    """Mean per-row Pearson correlation between predictions and targets.

    Both inputs are expected to be 2-D tensors of identical shape; the
    correlation is computed along ``dim=1`` for every row and the row values
    are then averaged. Rows are presumably individual cells/spots and columns
    the predicted targets -- TODO confirm against the competition metric.

    The previous implementation centred on the global mean and averaged the
    element-wise ratio ``t*p / (|t|*|p|)``, which only measures sign agreement
    and is not a correlation coefficient.
    """

    # Added to the denominator so constant rows yield 0 instead of NaN.
    EPS = 1e-7

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, pred, target):
        """Return a 0-dim tensor holding the mean row-wise Pearson correlation."""
        # Centre every row on its own mean.
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)
        # Pearson r = sum(p*t) / (||p|| * ||t||) on the centred rows.
        covariance = (pred_centered * target_centered).sum(dim=1)
        scale = pred_centered.norm(dim=1) * target_centered.norm(dim=1) + self.EPS
        return (covariance / scale).mean()

class MeanGenesCorrelation(torch.nn.Module):
    """Maximum over rows of a mean element-wise agreement score.

    Both tensors are centred on their own global mean. For every element the
    product of the two deviations is divided by the product of their
    magnitudes (plus 1e-7); these ratios are averaged along ``dim=1`` and the
    largest row average is returned.

    NOTE(review): this is not a per-gene Pearson correlation, and apart from
    the final ``max`` it matches the formula used for the cell metric --
    confirm the intended axis and reduction before relying on the value.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, pred, target):
        """Return a 0-dim tensor with the largest row-averaged ratio."""
        # Deviations from the mean over all elements of each tensor.
        pred_dev = pred - pred.mean()
        target_dev = target - target.mean()
        numerator = target_dev * pred_dev
        denominator = (
            torch.sqrt(target_dev * target_dev) * torch.sqrt(pred_dev * pred_dev) + 1e-7
        )
        row_scores = torch.mean(numerator / denominator, dim=1)
        return torch.max(row_scores)
    
class CustomLoss(torch.nn.Module):
    """Training loss: Huber loss with ``delta=0.5``.

    The original ``forward`` also evaluated a SmoothL1 loss but overwrote the
    result immediately, so only the Huber term ever reached the optimizer.
    The discarded computation is removed; the returned value is unchanged.

    NOTE(review): if a combination of both losses was intended, that has to be
    added explicitly -- it was never in effect.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Retained so the attribute still exists; not evaluated in forward().
        self.metricf = nn.SmoothL1Loss(beta=2)
        self.metricf2 = nn.HuberLoss(delta=0.5)

    def forward(self, pred, target):
        """Return the Huber loss between ``pred`` and ``target``."""
        return self.metricf2(pred, target)

## 유틸 클래스 및 함수정의
#### 딥러닝 데이터셋

In [8]:
class Dataset(object):
  """Map-style dataset that pairs a PNG image with the numeric columns of a row.

  Args:
    frame: pandas DataFrame with an ``ID`` column; every column from the third
      one onward is read as a float label.
    datadir: directory containing ``<ID>.png`` for every row.
    trainmode: stored on the instance; not consulted by ``__getitem__``.
  """
  def __init__(self, frame, datadir, trainmode=True) -> None:
    super().__init__()
    self.frame = frame
    self.datadir = datadir
    self.trainmode = trainmode
    # Build the transform once instead of once per sample.
    self._to_tensor = torchvision.transforms.ToTensor()

  def __getitem__(self, index):
    """Return ``{"img": float32 tensor, "label": float32 tensor}`` for a row position."""
    # Samplers and random_split hand out positions 0..len-1, so index by
    # position rather than by label.
    row = self.frame.iloc[index]
    path = os.path.join(self.datadir, row['ID']+".png")
    img = sitk.GetArrayFromImage(sitk.ReadImage(path))
    # One vectorised conversion of the label columns; np.array copies, so the
    # buffer handed to torch is writable.
    label = np.array(row.iloc[2:], dtype=np.float32)
    return {
      "img": self._to_tensor(img).type(torch.float32),
      "label": torch.from_numpy(label),
    }

  def __len__(self):
    """Number of rows in the underlying frame."""
    return len(self.frame)

#### 데이터 분할 및 로더

In [9]:
# Build the full training dataset, split it 90% / 10% into training and
# validation subsets, and wrap each split (plus the test set) in a DataLoader.
# The identity collate function delivers every batch as a plain list of sample
# dicts; stacking into tensors happens in the train/valid/inference loops.
dataset = Dataset(frame=trainframe, datadir=f"{root}train")
trainset, validset = random_split(dataset, lengths=[0.9, 0.1])
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda samples: samples,
)
validloader = DataLoader(
    validset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda samples: samples,
)
testloader = DataLoader(
    Dataset(frame=testframe, datadir=f"{root}test", trainmode=False),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda samples: samples,
)

## 모델 정의
#### 어텐션 모델 적용 및 응용

In [10]:
class CustomResNet(torch.nn.Module):
  """ResNet-50 regressor whose output vector is passed through a CBAM block.

  Args:
    outchannel: number of output values; also the channel count given to CBAM.
  """
  def __init__(self, outchannel) -> None:
    super().__init__()
    self.net = resnet50()
    self.attention = CBAM(outchannel)
    # Swap the stock head for a bias-free projection from 2048 features to
    # ``outchannel``. It stays wrapped in a Sequential so parameter names
    # remain ``net.fc.0.*``.
    head = nn.Linear(2048, outchannel, bias=False)
    self.net.fc = nn.Sequential(head)

  def forward(self, x):
    """Run the backbone, append two singleton axes, and apply the attention block."""
    features = self.net(x)
    return self.attention(features[..., None, None])

class CustomResNext(torch.nn.Module):
  """ResNeXt-50 (32x4d) regressor followed by a CBAM block on its output vector.

  Args:
    outchannel: number of output values; also the channel count given to CBAM.
  """
  def __init__(self, outchannel) -> None:
    super().__init__()
    self.net = resnext50_32x4d()
    self.attention = CBAM(outchannel)
    # Bias-free 2048 -> outchannel head, kept inside a Sequential so the
    # parameter names stay ``net.fc.0.*``.
    projection = nn.Linear(in_features=2048, out_features=outchannel, bias=False)
    self.net.fc = nn.Sequential(projection)

  def forward(self, x):
    """Backbone output gets two trailing singleton axes before the attention block."""
    out = self.net(x)
    out = out[..., None, None]
    return self.attention(out)
class CustomMobileNet(torch.nn.Module):
  """MobileNetV3-Small regressor followed by a CBAM block on its output vector.

  Args:
    outchannel: number of output values; also the channel count given to CBAM.
  """
  def __init__(self, outchannel) -> None:
    super().__init__()
    self.net = mobilenet_v3_small()
    self.attention = CBAM(outchannel)
    # Replacement classifier: 576 -> 1024 -> outchannel with a Hardswish in
    # between. A Dropout(p=0.2) between the activation and the last layer is
    # disabled here.
    head_layers = [
      nn.Linear(in_features=576, out_features=1024, bias=True),
      nn.Hardswish(),
      nn.Linear(in_features=1024, out_features=outchannel, bias=True),
    ]
    self.net.classifier = torch.nn.Sequential(*head_layers)

  def forward(self, x):
    """Backbone output gets two trailing singleton axes before the attention block."""
    logits = self.net(x)
    return self.attention(logits[..., None, None])
class CBAM(nn.Module):
    """Channel attention followed by spatial attention on a 4-D input.

    Args:
        channel: number of channels (axis 1) of the input.
        reduction: bottleneck ratio of the channel-attention MLP; the hidden
            width is ``channel // reduction``.

    ``forward`` returns the re-weighted tensor with singleton *spatial* axes
    removed. The batch and channel axes are always kept, so a batch of one
    sample comes back as ``(1, C)`` rather than ``(C,)``; the previous bare
    ``squeeze()`` dropped the batch axis in that case.
    """

    def __init__(self, channel, reduction=16):
        super(CBAM, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Channel attention: one MLP shared by the avg- and max-pooled descriptors.
        self.channel_fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )

        # Spatial attention: 7x7 convolution over the [mean, max] channel maps.
        self.spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def forward(self, x):
        """Apply both attention stages to a 4-D tensor ``x``."""
        # Unpacking four sizes enforces a 4-D input.
        b, c, _, _ = x.size()

        # Channel attention: sigmoid gate built from pooled descriptors.
        avg_out = self.channel_fc(self.avg_pool(x).view(b, c))
        max_out = self.channel_fc(self.max_pool(x).view(b, c))
        channel_att = torch.sigmoid(avg_out + max_out).view(b, c, 1, 1)

        x = x * channel_att

        # Spatial attention: sigmoid gate built from per-position mean and max
        # over the channel axis.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_att = torch.sigmoid(self.spatial_conv(torch.cat([avg_out, max_out], dim=1)))

        out = x * spatial_att
        # Squeeze only axes 3 and 2 (each is a no-op unless that axis has size
        # 1) so the batch and channel axes survive for any batch size.
        return out.squeeze(3).squeeze(2)

#### 학습 및 평가 함수 정의

In [11]:
def valid(net, validloader, device):
    """Evaluate ``net`` on ``validloader`` and return batch-averaged metrics.

    Args:
        net: model to evaluate; it is switched to eval mode.
        validloader: iterable yielding batches as lists of ``{"img", "label"}``
            sample dicts.
        device: device the metric objects and the batch tensors are moved to.

    Returns:
        dict with keys ``"mse"``, ``"mae"``, ``"mcc"`` and ``"mgc"``. Each value
        is the sum of the per-batch metric results divided by
        ``len(validloader)``.
    """
    metric_fns = {
        "mae": MeanAbsoluteError().to(device),
        "mse": MeanSquaredError().to(device),
        "mcc": MeanCellCorrelation().to(device),
        "mgc": MeanGenesCorrelation().to(device),
    }
    net.eval()
    totals = dict.fromkeys(metric_fns, 0)
    n_batches = len(validloader)
    with torch.no_grad():
        for batch in tqdm(validloader):
            # A batch is a list of sample dicts; stack into batch tensors.
            images = torch.stack([sample["img"] for sample in batch]).to(device)
            targets = torch.stack([sample["label"] for sample in batch]).to(device)
            prediction = net(images)
            for key, metric in metric_fns.items():
                totals[key] += metric(prediction, targets).item()

    return {key: totals[key] / n_batches for key in ("mse", "mae", "mcc", "mgc")}

def train(net, loader, validloader, epoch, optimizer, loss_fn, device, save_dir ):
    """Train ``net`` for ``epoch`` epochs, validating and checkpointing after each.

    Args:
        net: model to optimise.
        loader: training loader yielding batches as lists of
            ``{"img", "label"}`` sample dicts.
        validloader: loader handed to ``valid`` after every epoch.
        epoch: number of epochs to run.
        optimizer: optimizer already bound to ``net``'s parameters.
        loss_fn: callable ``loss_fn(output, label)`` returning a scalar tensor.
        device: device the batch tensors are moved to.
        save_dir: directory for the checkpoint ``net.pt``; created if missing.

    Returns:
        dict with per-epoch lists under ``"mcc"``, ``"mgc"``, ``"loss"`` (mean
        training loss per batch) and ``"mse"``.

    The checkpoint is written after the first epoch and afterwards whenever the
    validation MSE is less than or equal to the best value seen so far.
    """
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, "net.pt")

    length = len(loader)
    history = {"mcc":[], "mgc":[], "loss":[],"mse":[]}
    best_mse = float("inf")
    for e in range(epoch):
        # valid() leaves the model in eval mode, so re-enable training mode
        # once per epoch rather than once per batch.
        net.train()
        losses = 0
        for batch in tqdm(loader):
            img = torch.stack([b["img"] for b in batch]).to(device)
            label = torch.stack([b["label"] for b in batch]).to(device)
            # Clear gradients before the backward pass so nothing stale is applied.
            optimizer.zero_grad()
            loss = loss_fn(net(img), label)
            loss.backward()
            optimizer.step()
            losses += loss.item()

        hist = valid(net, validloader, device)
        mae = hist["mae"]
        mse = hist["mse"]
        mcc = hist["mcc"]
        mgc = hist["mgc"]
        print(f'epoch: {e+1}, loss: {losses/length} mse: {mse} mae: {mae} mcc: {mcc} mgc: {mgc}')
        if e == 0 or mse <= best_mse:
            torch.save(net.state_dict(), checkpoint_path)
            print("model is updated")
        best_mse = min(best_mse, mse)
        history["mcc"].append(mcc)
        history["mgc"].append(mgc)
        history["loss"].append(losses/length)
        history["mse"].append(mse)
    return history

#### 실제 학습을 위한 모델 준비

In [12]:
# Build the model as float32 and move it to the target device *before* the
# optimizer is created, so the optimizer is constructed over the parameters in
# their final location.
net = CustomResNet(classnum).float().to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
loss_fn = CustomLoss().to(DEVICE)
print("model is ready")

model is ready


## 학습

In [13]:
# Run 30 training epochs; the checkpoint directory is f"{root}Models".
train(
    net=net,
    loader=trainloader,
    validloader=validloader,
    epoch=30,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=DEVICE,
    save_dir=f"{root}Models",
)

  "label": torch.Tensor(label)
  7%|▋         | 7/99 [00:04<01:00,  1.53it/s]


KeyboardInterrupt: 

## 제출 및 결과

In [14]:
def submission(net, testloader, device, keys, subframePath = f'{root}sample_submission.csv', result_path='./result.csv'):
  """Predict the test set and write the predictions into a submission CSV.

  Args:
    net: trained model.
    testloader: loader over the test set, forwarded to ``inference``.
    device: device used for inference.
    keys: accepted for backward compatibility; not used.
    subframePath: CSV template whose first column is kept and whose remaining
      columns are overwritten with the predictions.
    result_path: where the filled-in CSV is written (default ``./result.csv``).

  Returns:
    The filled-in submission DataFrame.

  Raises:
    ValueError: if the prediction array does not match the template's
      ``(rows, columns - 1)`` shape.
  """
  net.eval()
  submit = pd.read_csv(subframePath)
  preds = inference(net, testloader, device).astype(np.float32)
  expected_shape = (len(submit), submit.shape[1] - 1)
  # Fail loudly instead of letting pandas broadcast a wrongly-shaped array.
  if preds.shape != expected_shape:
    raise ValueError(
      f"prediction shape {preds.shape} does not match submission template {expected_shape}"
    )
  submit.iloc[:, 1:] = preds
  submit.to_csv(result_path, index=False)
  return submit
def inference(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and return all predictions as one array.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable yielding batches as lists of sample dicts with
            an ``"img"`` tensor.
        device: device the stacked batch is moved to.

    Returns:
        numpy array with the per-batch outputs concatenated along axis 0.
    """
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # Stack on the host and transfer once per batch, as in
            # valid()/train(), instead of one transfer per sample.
            img = torch.stack([b["img"] for b in batch]).to(device)
            # No detach needed: outputs carry no graph under no_grad.
            preds.append(model(img).cpu())

    return torch.cat(preds).numpy()

#### 가지치기

In [15]:
def pruning(net:nn.Module):
    for name, module in net.named_modules():
        # 모든 2D-conv 층의 20% 연결에 대해 가지치기 기법을 적용
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
            # prune.l1_unstructured(module, name='bias', amount=2)
        # 모든 선형 층의 40% 연결에 대해 가지치기 기법을 적용
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.3)
            # prune.l1_unstructured(module, name='bias', amount=3)
        

#### 체크포인트 모델 로드 및 결과

In [16]:
# Rebuild the architecture, restore the best checkpoint, prune it, and write
# the submission file.
net = CustomResNet(classnum).float().to(DEVICE)
# Build the checkpoint path the same way the training call builds its save
# directory (f"{root}Models" + "net.pt"), and remap tensors onto the current
# device so a checkpoint saved on GPU also loads on a CPU-only machine.
net.load_state_dict(
    torch.load(os.path.join(f"{root}Models", "net.pt"), map_location=DEVICE, weights_only=True)
)
pruning(net)
submission(net, testloader, DEVICE, submissionframe.columns[1:])

100%|██████████| 36/36 [00:08<00:00,  4.29it/s]


Unnamed: 0,ID,AL645608.7,HES4,TNFRSF18,TNFRSF4,SDF4,ACAP3,INTS11,MXRA8,AL391244.2,...,MT-ATP8,MT-ATP6,MT-CO3,MT-ND3,MT-ND4L,MT-ND4,MT-ND5,MT-CYB,BX004987.1,AL592183.1
0,TEST_0000,-0.000995,0.026969,-0.002195,-0.001081,0.646813,0.105283,0.203616,0.026514,-0.000420,...,0.290498,3.642795,3.859397,3.167519,1.478432,3.378892,1.713870,3.356684,-0.000674,-0.003059
1,TEST_0001,0.001193,0.017016,-0.003633,-0.001891,0.630105,0.085901,0.181992,0.027106,-0.003029,...,0.679742,4.191400,4.208060,3.702670,2.150257,3.981610,2.268077,3.813267,-0.000133,-0.004224
2,TEST_0002,-0.000641,0.024702,-0.003523,-0.002348,0.686605,0.092786,0.202835,0.026404,-0.001712,...,0.467365,3.834481,3.961667,3.337365,1.741116,3.602150,1.916714,3.497329,-0.000962,-0.006442
3,TEST_0003,0.000324,0.031164,-0.001641,-0.000882,0.612918,0.088884,0.194346,0.029698,-0.001187,...,0.362832,3.759809,3.897166,3.184016,1.621146,3.476542,1.827107,3.403340,-0.000438,-0.004821
4,TEST_0004,-0.002218,0.047416,-0.005217,-0.002718,0.672703,0.100077,0.259149,0.036827,0.001217,...,0.278126,3.750371,4.108978,3.457803,1.486155,3.391016,1.772501,3.539382,-0.004026,-0.008743
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2272,TEST_2272,0.000223,0.016898,-0.003225,-0.001345,0.667331,0.089585,0.176842,0.023023,-0.003823,...,0.564115,3.975902,4.020772,3.484323,1.960104,3.737813,2.131976,3.592626,-0.000460,-0.005816
2273,TEST_2273,0.000481,0.006935,-0.003336,-0.000497,0.605432,0.082924,0.169376,0.028189,-0.003892,...,0.727561,4.226516,4.203852,3.734321,2.242800,3.989810,2.361582,3.797792,-0.000482,-0.003150
2274,TEST_2274,0.000502,0.027697,-0.003052,-0.002865,0.648260,0.089286,0.207742,0.025825,-0.002131,...,0.485581,3.882975,3.998059,3.371432,1.794583,3.638407,2.004796,3.492051,-0.001242,-0.004674
2275,TEST_2275,-0.000739,0.031264,-0.002814,-0.003093,0.695549,0.098013,0.227442,0.030719,-0.000653,...,0.324420,3.620099,3.850816,3.140908,1.462913,3.395839,1.687703,3.319991,-0.001508,-0.006086
