In [None]:
import csv
import os
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

class MyDataset:
    """Map-style dataset over the galaxy image files of one split.

    Every entry found directly inside ``<root>/<dataset_type>`` is one sample.
    ``__getitem__`` returns the array loaded from the file together with the
    integer parsed from the file stem.

    Args:
        dataset_type: Name of the split sub-directory, e.g. ``"train"`` or
            ``"test"``.
        root: Directory containing the split sub-directories. Defaults to the
            Kaggle input location ``/kaggle/input/galaxies/dataset``.
    """

    # Human-readable name of each galaxy class, keyed by class index.
    # Kept on the class so it is reachable as `dataset.galaxy_names`.
    galaxy_names = {
        0: "Edge-on without Bulge",
        1: "Unbarred Tight Spiral",
        2: "Edge-on with Bulge",
        3: "Merging",
        4: "In-between Round Smooth",
        5: "Barred Spiral",
        6: "Disturbed",
        7: "Unbarred Loose Spiral",
        8: "Cigar Shaped Smooth",
        9: "Round Smooth",
    }

    def __init__(self, dataset_type, root=None):
        if root is None:
            root = Path('/') / "kaggle" / "input" / "galaxies" / "dataset"
        self.ds_type = dataset_type
        self.ds_dir = Path(root) / dataset_type
        # The notebook refers to this directory under both names; keep them in sync
        self.dest_dir = self.ds_dir

        # Sorted so sample indices are stable between runs (iterdir order is arbitrary)
        self.images = sorted(self.ds_dir.iterdir())
        print(f"Dataset of {len(self.images)} images loaded")

    def get_split(self, train_len):
        """Randomly split the dataset into a training and a validation subset.

        Args:
            train_len: Fraction of the samples (between 0 and 1) that goes
                into the first subset; the remainder goes into the second.

        Returns:
            The two subsets produced by ``torch.utils.data.random_split``,
            training subset first.

        Raises:
            ValueError: If ``train_len`` is not within ``[0, 1]``.
        """
        if not 0 <= train_len <= 1:
            raise ValueError(f"train_len must be a fraction in [0, 1], got {train_len}")
        n_train = int(train_len*len(self))
        n_valid = len(self) - n_train
        return torch.utils.data.random_split(self, [n_train, n_valid])

    def __getitem__(self, index):
        """Return ``(image, idx)`` for the sample at position ``index``.

        ``image`` is whatever ``np.load`` returns for the file and ``idx`` is
        the file stem converted to ``int``.
        """
        f = self.images[index]
        img = np.load(f)
        # TODO: train-only augmentation (resize / crop / normalise) is not implemented yet

        # NOTE(review): the trainer uses idx as the class label for the train
        # split and as the file id for the test split, while the image-grid
        # cell reads the class from the first character of the file name.
        # Confirm the file naming scheme matches int(stem).
        idx = int(f.stem)
        return (img, idx)

    def __len__(self):
        """Number of files found in the split directory."""
        return len(self.images)


class Operation(nn.Module):
    """Convolution -> batch normalisation -> ReLU building block.

    The convolution is 3x3 with stride 1 and padding 1, so the spatial size
    of the input is preserved and only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the convolution, the batch norm and the ReLU, in that order."""
        return self.relu(self.batchnorm(self.conv(x)))


class Net(nn.Module):
    """VGG16-style classifier built from ``Operation`` blocks.

    The convolutional part halves the spatial size five times, and the first
    linear layer expects ``512*8*8`` features, i.e. an 8x8 map after pooling
    (for example a 256x256 input).

    Args:
        in_channels: Number of channels of the input images.
        n_classes: Number of output logits.
    """

    # Convolutional layout: integers are the output channels of an Operation
    # block, "M" is a 2x2 max-pool with stride 2.
    _LAYOUT = (
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, "M",
        512, 512, 512, "M",
        512, 512, 512, "M",
    )

    def __init__(self, in_channels=3, n_classes=10):
        super().__init__()
        layers = []
        channels = in_channels
        for item in self._LAYOUT:
            if item == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(Operation(channels, item))
                channels = item
        # Same layer order (and therefore the same state-dict keys) as an
        # explicit nn.Sequential listing of the blocks
        self.down = nn.Sequential(*layers)
        self.fully_connected = nn.Sequential(
            # 512*8*8: channels * spatial size after being max-pooled 5 times
            # 4096: hidden width taken from the original architecture
            nn.Linear(512*8*8, 4096),
            nn.ReLU(True), # in-place
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, n_classes),
        )

    def forward(self, x):
        """Return the class logits, shape ``[batch, n_classes]``, for a batch of images."""
        x = self.down(x)
        # From [batch, 512, 8, 8] to [batch, 32768]
        x = x.reshape(x.shape[0], -1)
        return self.fully_connected(x)

class Trainer:
    """Train, validate and run inference with the galaxy classifier.

    On construction the trainer picks the device, loads an existing model
    checkpoint if one is present (otherwise creates a fresh ``Net``), builds
    the optimizer and splits the ``"train"`` dataset 80/20 into training and
    validation loaders.

    Attributes set in ``__init__``: ``device``, ``dir_weights``, ``model``,
    ``lr``, ``loss_fn``, ``optimizer``, ``epochs``, ``dataset``, ``train_ds``,
    ``valid_ds``, ``train_dl``, ``valid_dl``, ``losses``, ``accuracies``.
    """

    # Number of classes the per-label statistics are reported for
    N_CLASSES = 10

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Old model checkpoint
        model_path = Path('/') / "kaggle" / "input" / "galaxies" / "model.ckpt"
        # New model checkpoint
        self.dir_weights = Path("/") / "kaggle" / "working"
        self.dir_weights.mkdir(parents=True, exist_ok=True)
        # Load model checkpoint
        if os.path.isfile(model_path):
            # The checkpoint is a whole pickled module (see torch.save in
            # train()), so only load files you trust. map_location lets a
            # checkpoint written on GPU be restored on a CPU-only machine.
            # NOTE(review): recent PyTorch releases default torch.load to
            # weights_only=True, which rejects whole-module checkpoints;
            # pass weights_only=False there if loading fails.
            self.model = torch.load(model_path, map_location=self.device).to(self.device)
            print("Using pre-trained weights")
        else:
            self.model = Net().to(self.device)
            print("Training from scratch")
        self.lr = 1e-5
        self.loss_fn = nn.CrossEntropyLoss()
        # TODO: TRY ADAM
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr) # weight_decay=1e-5
        self.epochs = 20
        self.dataset = MyDataset("train")
        # Dataset 80 / 20 split
        self.train_ds, self.valid_ds = self.dataset.get_split(0.8)
        self.train_dl = DataLoader(self.train_ds, batch_size=32, shuffle=True)
        self.valid_dl = DataLoader(self.valid_ds, batch_size=32)
        # Per-epoch statistics: (train, valid) pairs
        self.losses = []
        self.accuracies = []

    def train(self):
        """Run ``self.epochs`` epochs of training followed by validation.

        After every epoch the whole model is saved to
        ``self.dir_weights / "model.ckpt"`` and the (train, valid) loss and
        accuracy pairs are appended to ``self.losses`` / ``self.accuracies``.
        """
        for epoch in range(self.epochs):
            correct_t, loss_t = self.train_epoch()
            correct_v, loss_v, _ = self.valid_epoch(self.valid_dl)
            # Save model
            torch.save(self.model, self.dir_weights / "model.ckpt")
            # Stats
            self.losses.append((loss_t, loss_v))
            self.accuracies.append((correct_t, correct_v))
            print(f"Epoch {epoch+1:>3}/{self.epochs} - Train accuracy {correct_t:5.2f}% loss {loss_t:8f} - Valid accuracy {correct_v:5.2f}% loss {loss_v:8f}")

    def train_epoch(self):
        """Train for one pass over ``self.train_dl``.

        Returns:
            ``(accuracy, loss)``: accuracy in percent and the per-sample
            average loss over the epoch.
        """
        size = len(self.train_dl.dataset)
        loss_value, correct = 0.0, 0
        # valid_epoch / test_epoch leave the model in eval mode; switch
        # dropout and batch-norm back to training behaviour
        self.model.train()
        for batch, (inputs, labels) in enumerate(self.train_dl):
            print(f"Train batch: {batch:>3}/{len(self.train_dl):<3}", end='\r')
            # Transfer to CPU or GPU
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            # Set gradients to zero
            self.optimizer.zero_grad()
            # Get predictions
            logits = self.model(inputs)
            # Calculate loss
            loss = self.loss_fn(logits, labels)
            # The loss is a batch mean: weight it by the batch size so the
            # final division by `size` yields a true per-sample average
            loss_value += loss.item() * labels.shape[0]
            # Count the correct predictions of this batch
            correct += (logits.argmax(1) == labels).sum().item()
            # Backpropagation
            loss.backward()
            self.optimizer.step()
        # max(..., 1) avoids a ZeroDivisionError on an empty dataset
        accuracy = (correct / max(size, 1)) * 100
        loss_value /= max(size, 1)
        return accuracy, loss_value

    def valid_epoch(self, dataloader):
        """Evaluate the model on ``dataloader`` without updating it.

        Leaves the model in eval mode.

        Returns:
            ``(accuracy, loss, accuracy_by_label)``: overall accuracy in
            percent, per-sample average loss, and a dict mapping every class
            index ``0..N_CLASSES-1`` to its accuracy in percent (``0.0`` for
            classes that do not occur in the data).
        """
        # Switch dropout / batch-norm layers to inference behaviour
        self.model.eval()
        loss, correct = 0.0, 0
        size_by_label = {label: 0 for label in range(self.N_CLASSES)}
        correct_by_label = {label: 0 for label in range(self.N_CLASSES)}
        with torch.no_grad():
            # For each batch
            for batch, (inputs, labels) in enumerate(dataloader):
                print(f"Valid batch: {batch:>3}/{len(dataloader):<3}", end='\r')
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                logits = self.model(inputs)
                # Batch mean weighted by batch size, see train_epoch
                loss += self.loss_fn(logits, labels).item() * labels.shape[0]
                # Compare predictions with real values
                for guess, real in zip(logits.argmax(1).tolist(), labels.tolist()):
                    size_by_label[real] += 1
                    if guess == real:
                        # Compute per label accuracy
                        correct_by_label[real] += 1
                        correct += 1
        size = len(dataloader.dataset)
        accuracy = (correct / max(size, 1)) * 100
        loss /= max(size, 1)
        # Turn the per-label hit counts into percentages
        accuracy_by_label = {
            label: hits / (size_by_label[label] or 1) * 100
            for label, hits in correct_by_label.items()
        }
        return accuracy, loss, accuracy_by_label

    def test_epoch(self, dataloader):
        """Predict a class for every sample of ``dataloader``.

        The second element of each batch is treated as the sample's file id.
        Leaves the model in eval mode.

        Returns:
            A list of ``(predicted_class, file_id)`` tuples of ints, in
            dataloader order.
        """
        predictions = []
        self.model.eval()
        with torch.no_grad():
            for batch, (inputs, f_names) in enumerate(dataloader):
                print(f"Test batch: {batch:>3}/{len(dataloader):<3}", end='\r')
                inputs = inputs.to(self.device)
                logits = self.model(inputs)
                for guess, f_name in zip(logits.argmax(1).tolist(), f_names):
                    predictions.append((guess, int(f_name)))
        return predictions

trainer = Trainer()
# The class-name mapping is read from the dataset, as in the other cells
print(trainer.dataset.galaxy_names)

In [None]:
# Plot images by type
images_dir = trainer.dataset.dest_dir
img_by_type = {galaxy_type: [] for galaxy_type in range(10)}
# Sorted so the same examples are shown on every run
for img_name in sorted(os.listdir(images_dir)):
    # The first character of the file name is taken as the galaxy type, e.g. 3
    if not img_name[:1].isdigit():
        # Skip anything that does not follow that naming (e.g. hidden files)
        continue
    img_type = int(img_name[:1])
    img_by_type[img_type].append(img_name)

rows = 2
# squeeze=False keeps `ax` two-dimensional even when rows == 1
fig, ax = plt.subplots(rows, 10, figsize=(40, 4*rows), squeeze=False)
fig.suptitle('Image examples')
for i in range(10):
    for j in range(rows):
        title = i
        ax[j][i].set_title(title)
        if j >= len(img_by_type[i]):
            # Fewer than `rows` examples of this type: leave the panel empty
            ax[j][i].axis("off")
            continue
        img = np.load(images_dir / img_by_type[i][j])
        # Move the first axis last so imshow gets the channels in the last axis
        # NOTE(review): (2, 1, 0) also swaps the two spatial axes; if the
        # arrays are stored as (channels, height, width) then (1, 2, 0) would
        # keep the orientation. Confirm against the data.
        img = np.transpose(img, (2, 1, 0))
        ax[j][i].imshow(img)

In [None]:
#trainer.train()

In [None]:
# Plot loss and accuracy
if not trainer.losses:
    # Nothing was recorded (training skipped or not started): plotting an
    # empty history against a full epoch axis would raise
    print("No training history to plot - run trainer.train() first")
else:
    fig, (ax1, ax2) = plt.subplots(2, figsize=(10,11))
    # Get image sample
    img_size = trainer.train_ds[0][0].shape
    fig.suptitle(f'VGG16 - Image size: {img_size} - LR: {trainer.lr}')
    # Use the number of epochs actually recorded, which can be lower than
    # trainer.epochs if training was interrupted
    epochs = list(range(1, len(trainer.losses) + 1))
    # Each history entry is a (train, valid) pair, so plot() returns two lines
    train, valid = ax1.plot(epochs, trainer.losses)
    ax1.set_title('Loss')
    ax1.legend((train, valid), ("Train", "Valid"))
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')

    train, valid = ax2.plot(epochs, trainer.accuracies)
    ax2.set_title('Accuracy')
    ax2.legend((train, valid), ("Train", "Valid"))
    # Accuracies are stored as percentages already
    ax2.yaxis.set_major_formatter(mtick.PercentFormatter())
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy')
    plt.show()

In [None]:
# Print train and validation performance by galaxy type
def print_stats(acc, loss, accuracy_by_label, galaxy_names=None):
    """Print overall and per-class evaluation statistics.

    Args:
        acc: Overall accuracy in percent.
        loss: Loss value to report.
        accuracy_by_label: Mapping of class index to accuracy in percent.
        galaxy_names: Mapping of class index to display name. Defaults to
            ``trainer.dataset.galaxy_names`` of the notebook-level trainer.

    Raises:
        KeyError: If a class index has no entry in ``galaxy_names``.
    """
    if galaxy_names is None:
        galaxy_names = trainer.dataset.galaxy_names
    # The trailing blanks presumably overwrite what is left of the '\r'
    # progress line printed during evaluation
    print(f"Accuracy {acc:.2f}%    ")
    print(f"Loss {loss:8f}")
    for key, value in accuracy_by_label.items():
        galaxy_name = galaxy_names[key]
        print(f"- {value:5.1f}% <- {galaxy_name}")

# Evaluate the model on both loaders and report the statistics for each
for banner, loader in (
    ("--- Train dataset ---", trainer.train_dl),
    ("--- Valid dataset ---", trainer.valid_dl),
):
    print(banner)
    acc, loss, accuracy_by_label = trainer.valid_epoch(loader)
    print_stats(acc, loss, accuracy_by_label)

In [None]:
# Predict the class of every test image
dataset = MyDataset("test")
test_dl = DataLoader(dataset, batch_size=32)
predictions = trainer.test_epoch(test_dl)

# Replace the class index by its name (e.g. 5 -> Barred Spiral) and put the
# file id first, giving (file id, class name) rows
predictions_new = [
    (f_name, trainer.dataset.galaxy_names[pred])
    for pred, f_name in predictions
]

# Sort by filename
predictions_new.sort(key=lambda row: row[0])

# Write one CSV row per test image
with open('output.csv', 'w', newline='') as file:
    csv.writer(file).writerows(predictions_new)
print("\nDone!")