In [12]:
import os.path
from typing import Tuple

import datasets
import huggingface_hub
import torch

from torch import Tensor
from torch import nn, optim
from torch.utils import data
from tqdm import tqdm

from src.model.metric import RollingMean, confusion

In [13]:
# Log in to the Hugging Face Hub from the notebook before the dataset is downloaded in the next cell.
# Interactive login keeps the access token out of the notebook source.
# NOTE(review): only needed if "ikkiren/bigdata_ds" is private/gated — confirm.
huggingface_hub.notebook_login()

In [14]:
# Pull the dataset from the Hugging Face Hub; later cells index it by split ("train" / "test").
ds = datasets.load_dataset(path="ikkiren/bigdata_ds")

In [15]:
class WindSpeedClassifier(nn.Module):
	"""Small MLP that emits one raw logit per sample for binary strong-wind classification.

	No sigmoid is applied here; the logit is meant for ``nn.BCEWithLogitsLoss``.
	"""

	def __init__(self, input_size, hidden_size=16, dropout=0.3):
		"""
		Args:
			input_size: number of input features per sample.
			hidden_size: width of both hidden layers (default 16, the original hard-coded value).
			dropout: dropout probability after each hidden activation (default 0.3).
		"""
		super().__init__()
		self.model = nn.Sequential(
			nn.Linear(input_size, hidden_size),
			nn.LeakyReLU(),
			nn.Dropout(dropout),

			nn.Linear(hidden_size, hidden_size),
			nn.LeakyReLU(),
			nn.Dropout(dropout),

			nn.Linear(hidden_size, 1)
		)

	def forward(self, x):
		"""Return logits with the trailing size-1 output dimension removed.

		For a batch of shape (B, input_size) the result has shape (B,), including B == 1.
		A bare ``squeeze()`` would also drop the batch dimension of a single-sample batch
		(e.g. a final partial batch), producing a 0-d tensor whose shape no longer matches
		the (1,) target passed to the loss.
		"""
		return self.model(x).squeeze(-1)


In [16]:
class WindSpeedDataset(data.Dataset):
	"""Map-style dataset turning tabular rows into ``(features, label)`` pairs.

	``ds`` must support ``len(ds)`` and ``ds[idx][column_name]``. The label is True when
	the ``LABEL_COL`` value is strictly greater than ``STRONG_WIND``.
	"""
	# Feature columns, in the order they appear in the feature vector.
	INPUT_COL = ["AirNOW_O3", "CMAQ12KM_O3(ppb)", "CMAQ12KM_NO2(ppb)", "CMAQ12KM_CO(ppm)", "PBL(m)"]
	# Target column; thresholded into a binary label.
	LABEL_COL = "WSPD10(m/s)"

	# Threshold in the units of LABEL_COL (m/s, going by the column name).
	STRONG_WIND = 3
	FEATURES_NUM = len(INPUT_COL)

	def __init__(self, ds):
		self.ds = ds

	def __len__(self) -> int:
		return len(self.ds)

	def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
		"""Return ``(features, label)`` for row ``idx``.

		features: float32 tensor of length ``FEATURES_NUM``. The dtype is forced because
			``torch.tensor`` infers int64 when every value in a row is an integer, which a
			float32 ``nn.Linear`` cannot consume.
		label: 0-d bool tensor.
		"""
		sample = self.ds[idx]
		features = torch.tensor([sample[col] for col in self.INPUT_COL], dtype=torch.float32)
		label = torch.tensor(sample[self.LABEL_COL] > self.STRONG_WIND)
		return features, label


In [17]:
batch_size = 1024

# Wrap each split so every row becomes a (features, label) pair.
train_set, test_set = WindSpeedDataset(ds["train"]), WindSpeedDataset(ds["test"])

# Shuffle only the training split; evaluation order stays fixed.
train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

In [18]:
# Seed torch's global RNG (all devices) so weight init, dropout masks and the training
# DataLoader's shuffle order (drawn from the global RNG when no generator is given)
# repeat under Restart & Run All. GPU kernels may still add minor nondeterminism.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = WindSpeedClassifier(WindSpeedDataset.FEATURES_NUM).to(device)
# The classifier ends in a plain Linear layer (raw logits), so the sigmoid lives in the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [19]:
# Number of passes over the training split; best_loss starts at +inf.
# NOTE(review): best_loss is not read anywhere in the visible notebook — confirm it is still needed.
epochs, best_loss = 2, float("inf")


def open_log(epoch: int, train: bool):
	"""Create ``log/<train|test>/<epoch>.csv`` and return it opened for writing.

	The CSV header ``loss,accuracy,precision,recall`` is already written; the caller
	appends one row per batch and is responsible for closing the file.
	"""
	log_dir = os.path.join("log", "train" if train else "test")
	os.makedirs(log_dir, exist_ok=True)

	handle = open(os.path.join(log_dir, f"{epoch}.csv"), "w")
	handle.write("loss,accuracy,precision,recall\n")
	return handle


def checkpoint(epoch: int, net=None, opt=None):
	"""Save model and optimizer state to ``checkpoint/<epoch>.pth``.

	Args:
		epoch: epoch index; used as the file name and stored under the ``"epoch"`` key.
		net: module whose ``state_dict`` is saved; defaults to the notebook-global ``model``.
		opt: optimizer whose ``state_dict`` is saved; defaults to the notebook-global ``optimizer``.

	The state is written to a temporary file and then moved into place with ``os.replace``,
	so an interrupted save (e.g. KeyboardInterrupt) cannot leave a truncated ``<epoch>.pth``.
	"""
	net = model if net is None else net
	opt = optimizer if opt is None else opt

	subdir = "checkpoint"
	os.makedirs(subdir, exist_ok=True)

	state = {
		"model": net.state_dict(),
		"optimizer": opt.state_dict(),
		# Lets a resumed run know which epoch this checkpoint completes.
		"epoch": epoch,
	}

	path = os.path.join(subdir, f"{epoch}.pth")
	tmp_path = path + ".tmp"
	torch.save(state, tmp_path)
	os.replace(tmp_path, path)


def _run_epoch(epoch: int, train: bool):
	"""Run one pass over the train or test split, logging per-batch metrics.

	Writes one raw CSV row per batch via ``open_log`` and shows ``RollingMean(100)``-smoothed
	metrics in the progress bar. In train mode the model is updated with ``optimizer``;
	in test mode no gradients are tracked.

	``torch.set_grad_enabled`` is used as a context manager, so the previous global autograd
	mode is restored when the pass ends. Gradients are therefore not left disabled for
	later cells after the final test pass. The ``try/finally`` closes the progress bar and
	the log file even if the pass is interrupted.
	"""
	loader = train_loader if train else test_loader
	desc = f"Train {epoch + 1}" if train else f"Test  {epoch + 1}"
	model.train(train)  # same as model.train() / model.eval(): toggles the Dropout layers
	mean = RollingMean(100)

	log = open_log(epoch, train)
	bar = tqdm(loader, desc=desc)
	try:
		with torch.set_grad_enabled(train):
			for inputs, labels in bar:
				inputs, labels = inputs.to(device), labels.to(device)

				if train:
					optimizer.zero_grad()
				outputs = model(inputs)
				loss = criterion(outputs, labels.float())
				if train:
					loss.backward()
					optimizer.step()

				loss = loss.item()
				conf = confusion(outputs, labels)

				acc, pre, rec = conf.accuracy, conf.precision, conf.recall
				log.write(f"{loss},{acc},{pre},{rec}\n")

				# CSV keeps raw per-batch values; the bar shows the smoothed ones.
				acc, pre, rec, loss = mean(acc, pre, rec, loss)
				bar.set_postfix(acc=acc, pre=pre, rec=rec, loss=loss)
	finally:
		bar.close()
		log.close()


for epoch in range(epochs):
	_run_epoch(epoch, train=True)
	checkpoint(epoch)
	_run_epoch(epoch, train=False)

100%|██████████| 7256/7256 [36:32<00:00,  3.31it/s, acc=0.768, loss=0.455, pre=0.653, rec=0.652]
100%|██████████| 1814/1814 [08:42<00:00,  3.47it/s, acc=0.772, loss=0.446, pre=0.656, rec=0.667]
 80%|████████  | 5838/7256 [28:41<06:58,  3.39it/s, acc=0.769, loss=0.453, pre=0.657, rec=0.644]


KeyboardInterrupt: 