## Import modules

In [2]:
import os
import numpy as np
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as plt

from torchvision.datasets.cifar import CIFAR100
from torchvision.transforms import ToTensor

from config.models import MODEL_EXTRAS
from models.cls_hrnet import get_cls_net

import torch
import torch.nn as nn
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader

## Weight-loading settings

In [3]:
# Whether to load previously saved model weights.
LOAD_WEIGHT: bool = False
# Path to the previously saved model weights.
WEIGHT_PATH: str = "./save/20250220_152155/weight/model_epoch_200.pt"

## Device setting

In [20]:
# Enumerate GPUs in PCI bus order and expose only GPU index 1 to this process.
# These must be set before CUDA is first initialised in the kernel to take effect.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="1")

def get_recommended_device() -> str:
    """Return the name of the preferred available torch device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` only exists in PyTorch >= 1.12; look it up defensively
    # so older installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


device = get_recommended_device()
print(device)

cpu


## Data setting

In [18]:
# CIFAR-100 train/test splits, converted with ToTensor() only (no normalization
# or augmentation), wrapped in batch loaders; only the training loader shuffles.
# NOTE(review): HRNet.forward unpacks five inputs per batch, while these loaders
# yield (image, label) pairs — confirm this CIFAR-100 setup is a placeholder.
DATA_ROOT = './data'
BATCH_SIZE = 32

train = CIFAR100(root=DATA_ROOT, train=True, download=True, transform=ToTensor())
test = CIFAR100(root=DATA_ROOT, train=False, download=True, transform=ToTensor())

train_loader, test_loader = (
    DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=is_train)
    for dataset, is_train in ((train, True), (test, False))
)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
# Inspect the training split: the dataset summary, the number of elements in one
# sample (image, label), and the label of the first sample — each separated by a
# blank line.
first_sample = train[0]
print(train, len(first_sample), first_sample[1], sep="\n\n")

Dataset CIFAR100
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

2

19


## Model definition

In [None]:
class HRNet(nn.Module):
    """HRFormer-based regression model over image and converted auxiliary inputs.

    The numeric and horizontal/vertical pattern inputs are first mapped by
    converter blocks. They are concatenated along dim 1 with the two image
    inputs, passed through the HRFormer backbone, and then decoded by a
    regression head.

    Args:
        cfg: Model configuration. Keys read here: ``numeric_in_features``,
            ``hori_in_features`` (the misspelled legacy key ``hori_in_feauters``
            is still accepted), ``verti_in_features`` and ``in_channels``.
        backbone_version: Key into ``MODEL_EXTRAS`` selecting the backbone
            configuration; that configuration must provide
            ``["STAGE4"]["NUM_CHANNELS"]``.

    NOTE(review): ``Numeric_Convert_Block``, ``Pattern_Convert_Block``,
    ``HRFormer`` and ``Regressor`` are not defined or imported in the notebook
    cells shown here — TODO import them before running this cell.
    """

    def __init__(self, cfg: dict, backbone_version: str) -> None:
        super().__init__()
        self.cfg = cfg
        # The original code read ``self.backbone`` without ever assigning it, which
        # raises AttributeError. Resolving it from ``MODEL_EXTRAS`` (imported at the
        # top of the notebook but otherwise unused) via the previously unused
        # ``backbone_version`` argument is the presumed intent — confirm the key.
        self.backbone = MODEL_EXTRAS[backbone_version]

        # Prefer the correctly spelled key, but fall back to the legacy typo so
        # existing configs keep working.
        hori_key = "hori_in_features" if "hori_in_features" in cfg else "hori_in_feauters"

        self.numeric_convert_block = Numeric_Convert_Block(in_features=cfg["numeric_in_features"])
        self.hori_convert_block = Pattern_Convert_Block(in_features=cfg[hori_key])
        self.verti_convert_block = Pattern_Convert_Block(in_features=cfg["verti_in_features"])

        self.hrformer = HRFormer(self.backbone, cfg["in_channels"])  # HRFormer backbone
        # Head (decoder); its input width is the channel count of the first STAGE4 branch.
        self.regressor = Regressor(self.backbone["STAGE4"]["NUM_CHANNELS"][0], 1)

    def forward(self, x: list) -> torch.Tensor:
        """Run the model on one batch.

        Args:
            x: Exactly five inputs, in order ``[input_image_gis, input_image_ant,
                hori_input, verti_input, numeric_input]``.

        Returns:
            The regressor output computed from the first HRFormer output.
        """
        input_image_gis, input_image_ant, hori_input, verti_input, numeric_input = x

        # Map the pattern and numeric data through their converter blocks
        # (per the author's note, to 500x500 maps — confirm against the blocks).
        hori_convert = self.hori_convert_block(hori_input)
        verti_convert = self.verti_convert_block(verti_input)
        numeric_convert = self.numeric_convert_block(numeric_input)

        # Concatenate all inputs along dim 1 (presumably the channel axis of
        # (N, C, H, W) tensors — TODO confirm).
        con = torch.cat([input_image_gis, input_image_ant, numeric_convert, hori_convert, verti_convert], dim=1)

        y_hrt = self.hrformer(con)  # backbone forward pass
        # Decode only the first backbone output, which matches the STAGE4 channel
        # count the regressor was built with (presumably the highest-resolution branch).
        y = self.regressor(y_hrt[0])

        return y


# Build the model, move it to the selected device, and display a torchinfo summary.
# NOTE(review): CONFIG and HRT_VERSION are not defined in the cells shown here, and
# LOAD_WEIGHT / WEIGHT_PATH are not applied in this cell — TODO define the former
# and confirm whether saved weights should be loaded after construction.
model = HRNet(cfg=CONFIG, backbone_version=HRT_VERSION)
model = model.to(device)
summary(model)