In [1]:
# IMPORTS
# Numerical Operations
import math
import numpy as np

# Reading/Writing Data
import pandas as pd
import os
import csv

# For Progress Bar
from tqdm import tqdm

# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter

# Other Common Modules
from datetime import datetime, timedelta, timezone
import importlib

# IMPORT COMPLETE
# Only reached if every import above succeeded (a failing import stops the cell first).
print("Imports Done")

KeyboardInterrupt: 

In [None]:
# Detect the runtime. On Colab, mount Google Drive and switch into the homework folder;
# locally, stay in the current working directory.
# `timenow` is used later to build Config.time_string (a run identifier).
COLAB_WORKDIR = "/content/drive/MyDrive/Chronical/2023Spring/ML_drive/MLHW3"
# Fixed offset used for Colab timestamps — presumably the author's local zone (UTC+8); TODO confirm.
COLAB_UTC_OFFSET = timedelta(hours=8)

if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    # Ask for the time in the target zone explicitly instead of adding 8h to the VM's
    # naive clock, which is only correct while the Colab VM happens to run on UTC.
    timenow = datetime.now(timezone(COLAB_UTC_OFFSET))
    from google.colab import drive  # only importable on Colab
    drive.mount('/content/drive')
    os.chdir(COLAB_WORKDIR)
else:
    print('Running on Local')
    timenow = datetime.now()

In [None]:
# Load the project's Config module and reload it so edits to Config.py take effect
# without restarting the kernel.
# NOTE: the last line rebinds the name `Config` from the module to the `Config`
# attribute defined inside it. Re-running this cell still works only because
# `import Config` first rebinds the name back to the module before reload() is called.
import Config
importlib.reload(Config)
from Config import Config

In [None]:
# Resolve every project directory relative to the current working directory
# (on Colab the runtime-detection cell chdirs into the Drive folder first).
Config.base_path = os.getcwd()
Config.data_path = os.path.join(Config.base_path, "data")
Config.save_path = os.path.join(Config.base_path, ".model")
Config.output_path = os.path.join(Config.base_path, "output")

# makedirs(exist_ok=True) is idempotent on re-run, has no check-then-create race,
# and also creates any missing parent directories.
os.makedirs(Config.save_path, exist_ok=True)
os.makedirs(Config.output_path, exist_ok=True)

# Run identifier: zero-padded hour, minute, month, day (HHMMmmdd).
Config.time_string = f"{timenow.hour:02d}{timenow.minute:02d}{timenow.month:02d}{timenow.day:02d}"

print(f"{Config.base_path=}")
print(f"{Config.data_path=}")
print(f"{Config.save_path=}")
print(f"{Config.output_path=}")
print(f"{Config.time_string=}")

In [None]:
# RANDOMNESS FIXED
# Seed via the project helper before the datasets and model below are created, so any
# seeding it performs applies to them. Presumably it seeds Python/NumPy/PyTorch RNGs —
# TODO confirm in utils.py.
from utils import fix_randomness
fix_randomness(Config.seed)

# Dataset and DataLoader

In [None]:
# Import the project's ImageParser module and reload it so edits to ImageParser.py
# take effect without restarting the kernel.
import ImageParser
importlib.reload(ImageParser)

In [None]:
# Build one parser per data split, in order: training split first, then validation.
train_image_parser, valid_image_parser = (
    ImageParser.ImageParser(split_name) for split_name in ("train", "valid")
)

In [None]:
# Wrap each parser's parallel image/label lists in a Dataset (train first, then valid).
train_dataset, valid_dataset = (
    ImageParser.ImageDataset(parser.image_list, parser.label_list)
    for parser in (train_image_parser, valid_image_parser)
)

In [None]:
# Shuffle only the training data. Validation is iterated in a fixed order so its
# per-epoch results are reproducible and it does not consume shuffling RNG state.
Config.train_loader = DataLoader(train_dataset, Config.batch_size, shuffle=True)
Config.valid_loader = DataLoader(valid_dataset, Config.batch_size, shuffle=False)

# Models

In [None]:
# Import the project's model module and reload it so edits to ImageConvNet.py
# take effect without restarting the kernel.
import ImageConvNet
importlib.reload(ImageConvNet)

In [None]:
# Build the model on the configured device, the loss function, and the optimizer.
Config.model = ImageConvNet.ImageConvNet().to(Config.device)
# Store an *instance* of the loss. Assigning the bare class `nn.CrossEntropyLoss` would
# make `criterion(logits, labels)` call the constructor (treating the tensors as
# `weight` / `size_average`) instead of computing a loss.
Config.criterion = nn.CrossEntropyLoss()
Config.optimizer = torch.optim.Adam(Config.model.parameters(), lr=Config.learning_rate)
print("Model, Criterion, Optimizer complete")

# Training

In [None]:
# Import the project's Trainer module and reload it so edits to Trainer.py
# take effect without restarting the kernel.
import Trainer
importlib.reload(Trainer)

In [None]:
# Build the trainer and run training. Trainer() is given no arguments here, so it
# presumably reads the model, loaders, criterion and optimizer from Config — TODO confirm
# in Trainer.py.
trainer = Trainer.Trainer()
trainer.train()