In [None]:
from dataset import *
from train import *

import csv, random

import torch
import torchvision
import torchvision.transforms as transforms

# Seed every RNG this notebook draws from so Restart & Run All reproduces results:
# Python's `random` drives the class down-sampling and the train/val split; torch's
# default generator drives DataLoader(shuffle=True), the torchvision random
# transforms, and (presumably) the weight initialisation inside Baseline.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import warnings

warnings.filterwarnings('ignore')

In [None]:
# Root of the task-148 training data, relative to this notebook.
data_root = '../../data/.train/.task148/data/train/'
train_dir = data_root + 'images/'    # image folder (not referenced elsewhere in this notebook)
train_csv = data_root + 'train.csv'  # per-sample metadata columns and the target label

In [None]:
# Parse train.csv into per-class lists of (ID, one-hot metadata tuple, class index)
# and rebalance by randomly down-sampling over-represented classes.

# Target classes: exterior-wall types (concrete, masonry, panel, glass, other).
labels = ['10_콘크리트외벽', '20_조적외벽', '30_판넬외벽', '40_유리외벽', '50_기타외벽']
# Categorical metadata columns to one-hot encode. Keys are 0-based column positions in
# train.csv; values are the full, ordered vocabulary of each column (the slot order of
# the one-hot vector). Judging by the values, 36 is main vs. ancillary building, 44 the
# structural system and 50 the roof material — NOTE(review): confirm against the schema.
idxs = {29: ['0', '1', '2'], 31: ['0', '1'], 32: ['0', '1'], 36: ['', '부속건축물', '주건축물'], 
        44: ['', ' ', '강파이프구조', '경량철골구조', '공업화박판강구조(PEB)', '기타강구조', '기타구조', '기타조적구조', '기타철골철근콘크리트구조', '기타콘크리트구조', '목구조', '벽돌구조', '블록구조', '석구조', '시멘트블럭조', '일반목구조', '일반철골구조', '조립식판넬조', '조적구조', '철골구조', '철골철근콘크리트구조', '철골철근콘크리트합성구조', '철골콘크리트구조', '철근콘크리트구조', '콘크리트구조', '통나무구조', '트러스구조', '프리케스트콘크리트구조'], 
        50: ['', ' ', '(철근)콘크리트', '기와', '기타지붕', '슬레이트']}
label_col = 67  # column holding the class label (one of `labels`)

# Upper bound on samples kept per class index; classes not listed are kept whole.
# Add e.g. `4: 4800` to also down-sample class 4.
max_per_class = {0: 4800}

distributions = {}
# newline='' is what the csv module requires for file objects.
# NOTE(review): opened with the platform default encoding; pass encoding='utf-8'
# if train.csv is UTF-8 and the kernel may run under a non-UTF-8 locale.
with open(train_csv, 'r', newline='') as csvfile:
    reader = csv.reader(csvfile)
    next(reader, None)  # skip the header row
    for line in reader:
        if not line:  # tolerate blank lines instead of failing on line[0]
            continue
        ID = line[0]
        # `usage_list` is intentionally left bound after the loop: a later cell may read
        # len(usage_list) as the input width.
        usage_list = []
        for idx, options in idxs.items():
            usage = [0] * len(options)
            usage[options.index(line[idx])] = 1  # ValueError on an unseen category (fail loudly)
            usage_list += usage
        target = labels.index(line[label_col])  # ValueError on an unknown label
        distributions.setdefault(target, []).append((ID, tuple(usage_list), target))

# Rebalance; min() avoids random.sample's ValueError when a class has fewer rows than the cap.
for label_idx, cap in max_per_class.items():
    if label_idx in distributions:
        rows = distributions[label_idx]
        distributions[label_idx] = random.sample(rows, min(cap, len(rows)))

for label_idx in sorted(distributions):
    print(label_idx, len(distributions[label_idx]))

In [None]:
# Stratified split: for every class, a random quarter goes to validation and the
# rest to training. Plain lists are used (not sets) so the result — including its
# ordering — is reproducible under random.seed; iterating a set of tuples that
# contain str depends on PYTHONHASHSEED and changes between interpreter sessions.
train = []
val = []

for i in range(len(labels)):
    # .get(): a class with no rows contributes nothing instead of raising KeyError.
    # dict.fromkeys() drops exact duplicate rows (order-preserving) so no sample can
    # land in both splits — the set-based original also de-duplicated.
    samples = list(dict.fromkeys(distributions.get(i, [])))
    held_out = set(random.sample(range(len(samples)), len(samples) // 4))
    for pos, sample in enumerate(samples):
        (val if pos in held_out else train).append(sample)

print(len(train), len(val))

In [None]:
# Optimisation hyper-parameters, passed to the loaders and to model.train.
batch_size = 32       # training mini-batch size
epochs = 10           # number of training epochs
lr = 1e-4             # learning rate
weight_decay = 1e-6   # weight-decay coefficient

In [None]:
# Image pipelines and data loaders. Both pipelines resize to 256x256 and crop to 224.
# NOTE(review): mean/std look like per-channel statistics of this dataset — confirm
# they were computed on the training images. Normalize is stateless, so one
# instance is shared by both pipelines.
normalize = transforms.Normalize(mean=[0.5231, 0.5493, 0.5485], std=[0.2502, 0.2544, 0.2786])

# Training: random crop + random horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])

# Validation must be deterministic so metrics are comparable across epochs and runs:
# centre crop only, no random flip.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize])

val_batch_size = 128  # validation batch size (independent of the training batch size)

# Train_148 presumably loads each image by the sample ID and returns the metadata
# tuple and target alongside it — TODO confirm in dataset.py.
train_dataset = Train_148(infos=train, transform=train_transform)
val_dataset = Train_148(infos=val, transform=val_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

In [None]:
# Input width = length of the one-hot metadata vector: one slot per category of every
# encoded column in `idxs`. Computed from `idxs` directly instead of from the
# `usage_list` loop variable leaked by the CSV cell, so this cell does not depend on
# leftover loop state (and still works if the CSV had no data rows).
num_input = sum(len(options) for options in idxs.values())
# NOTE(review): `save` looks like an output/checkpoint path prefix — confirm in Baseline.
model = Baseline(num_input=num_input, num_classes=len(labels), save='./meta_')

In [None]:
# Fit the model on the training loader; the validation loader is handed over too,
# presumably for periodic evaluation — see Baseline.train.
model.train(
    train_loader,
    val_loader,
    epochs=epochs,
    lr=lr,
    weight_decay=weight_decay,
)