In [15]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import torch
from torch import nn
from torchvision import transforms
import torchvision
# Per-channel normalisation for 3-channel image tensors that the cells below
# have already scaled to [0, 1] (they divide by 255 before applying this).
# NOTE(review): these statistics are presumably computed on the training images — confirm.
CHANNEL_MEAN = np.array([0.1086, 0.0934, 0.0711])
CHANNEL_STD = np.array([0.1472, 0.123, 0.1032])
preprocess = transforms.Compose([transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD)])

import h5py

# Seed every RNG this notebook draws from (numpy and torch) so runs are reproducible.
SEED = 1337
np.random.seed(SEED)
torch.manual_seed(SEED)
# Make torch raise on operations that have no deterministic implementation
# instead of letting results vary silently between runs.
torch.use_deterministic_algorithms(mode=True)

from model import StarChartModel, StarData

In [16]:
def load_h5_shards(paths, image_key="X", label_key="y"):
    """Read and concatenate image/label arrays from several HDF5 files.

    Parameters
    ----------
    paths : iterable of str
        HDF5 files to read, in order.
    image_key, label_key : str
        Dataset names holding the images and the labels in each file.

    Returns
    -------
    (images, labels) : tuple of numpy.ndarray
        Each file's arrays concatenated along axis 0, in the order of ``paths``.
    """
    image_shards, label_shards = [], []
    for path in paths:
        with h5py.File(path, "r") as f:
            image_shards.append(np.array(f[image_key]))
            label_shards.append(np.array(f[label_key]))
    # The per-file arrays are local, so only the concatenated copies remain
    # alive once this returns (the original kept both in global scope).
    return np.concatenate(image_shards), np.concatenate(label_shards)


train_images, train_labels = load_h5_shards(["train.1.h5", "train.2.h5"])

In [17]:
batch_size = 32
# transpose(0, 3, 1, 2) moves the last axis to position 1 (presumably
# NHWC -> NCHW — confirm against the HDF5 layout); / 255 scales pixel values.
train_set = StarData(
    train_images.transpose(0, 3, 1, 2) / 255,
    train_labels,
    transform=preprocess,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

In [18]:
flag = np.zeros(shape=(33, 33), dtype=np.float64)  # 33x33 grid filled in by the test-inference cell

In [19]:
# NOTE(review): the meaning of [0.3, 0.15] is defined by StarChartModel in
# model.py (possibly dropout rates) — confirm there.
model = StarChartModel([0.3, 0.15])
model.train()  # put the module (and its submodules) in training mode

loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

In [20]:
num_epochs = 1  # passes over the training set
log_every = 32  # report the mean loss over this many mini-batches

# Re-enter training mode explicitly: the evaluation cell calls model.eval(),
# so re-running this cell afterwards would otherwise train in eval mode.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        # batch is indexed as [inputs, labels]
        inputs, labels = batch[0], batch[1]

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward
        loss = loss_func(outputs, labels)
        loss.backward()                    # backward
        optimizer.step()                   # update parameters

        # Print the mean loss of the last `log_every` mini-batches.
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

PATH = './Trained_Model.pt'
torch.save(model.state_dict(), PATH)
# Reload the checkpoint just written to confirm it round-trips into the model.
model.load_state_dict(torch.load(PATH))

[1,    32] loss: 0.691
[1,    64] loss: 0.667
[1,    96] loss: 0.695
[1,   128] loss: 0.666
[1,   160] loss: 0.666
[1,   192] loss: 0.658
[1,   224] loss: 0.644
[1,   256] loss: 0.616
[1,   288] loss: 0.604
[1,   320] loss: 0.604


<All keys matched successfully>

In [21]:
model.train(False)  # equivalent to model.eval(): switch to evaluation mode before inference
class TestData(torch.utils.data.Dataset):
    """Unlabelled dataset used for inference.

    Parameters
    ----------
    set_X : indexable, sized collection of array-likes
        Samples; item ``idx`` is converted to a float32 tensor.
    transform : callable, optional
        Applied to that tensor when truthy; otherwise the tensor is returned as-is.
    """

    def __init__(self, set_X, transform = None):
        self.X = set_X
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = torch.tensor(self.X[idx]).float()
        return self.transform(sample) if self.transform else sample

# Read only the slice of the test array that is actually used, and close the
# HDF5 file before inference starts (it used to stay open for the whole loop).
# NOTE(review): X[0][0] is treated like the training images
# (transpose(0, 3, 1, 2) then / 255) — confirm this is the intended slice.
with h5py.File("test.h5", "r") as F:
    test_images = F["X"][0][0]

test_set = TestData(test_images.transpose(0, 3, 1, 2) / 255, transform=preprocess)
# num_workers=0: TestData is defined in this notebook, and DataLoader workers
# started with the "spawn" method (default on Windows and macOS) cannot
# unpickle classes defined interactively in __main__.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
    shuffle=False, num_workers=0, pin_memory=True)

grid_size = 33  # side length of the output grid; batch i fills cell (i % 33, i // 33)
# Fresh grid so re-running this cell does not keep marks from a previous run.
flag = np.zeros((grid_size, grid_size))
with torch.no_grad():  # inference only: do not build an autograd graph per batch
    for i, batch in enumerate(test_loader):
        x = i % grid_size
        y = i // grid_size
        # NOTE(review): TestData yields bare tensors (no labels), so batch[0] is
        # the first image of the batch, not the whole batch as in the training
        # loop. Confirm this is intended.
        preds = torch.argmax(model(batch[0]), dim=1)
        if preds[0].item() == 1 and preds[1].item() == 1:
            flag[x][y] = 100
plt.imshow(flag, cmap="gray")
plt.show()
        