In [None]:
import os
import random
from pathlib import Path

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import wandb
from albumentations.pytorch.transforms import ToTensorV2
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from utils.data_utils import SpacecraftDataset, train_transforms, valid_transforms
from utils.seg_utils import IoU, SegModel
from utils.vis_utils import predict_mask, display_random_examples

In [None]:
# Paths: the input data directory and a scratch directory for exported artefacts.
DATA = Path("data")
WORKING = Path("working")

# Initial batch size handed to SegModel (the tuner cell may rescale it).
BATCH_SIZE = 32

# Seed the Python, NumPy and PyTorch RNGs so re-running the notebook is more
# repeatable (GPU kernels may still be non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

**Dataset**

Here we create the PyTorch training and validation datasets (`SpacecraftDataset`) for our data.


In [None]:
# Validation pipeline: resize to 520x520, normalise with the default
# A.Normalize() statistics and convert to a tensor. Note that this deliberately
# replaces the `valid_transforms` imported from utils.data_utils.
valid_transforms = A.Compose(
    [
        A.Resize(520, 520),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Training split uses the imported train_transforms; the "val" split uses the
# pipeline defined above.
train_dataset = SpacecraftDataset(transforms=train_transforms)
valid_dataset = SpacecraftDataset(split="val", transforms=valid_transforms)

In [None]:
# Visual sanity check of the training dataset before training.
# NOTE(review): presumably shows image/mask pairs with train_transforms applied — confirm in utils.vis_utils.
display_random_examples(train_dataset)

In [None]:
# Number of segmentation classes the pretrained heads are re-targeted to.
NUM_CLASSES = 4

model = torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(weights="DEFAULT")

# Replace the final 1x1 conv (index 4) of the main and auxiliary heads so they
# emit NUM_CLASSES channels. The input width is read from the layer being
# replaced rather than hard-coded (256 for the main head, 10 for the aux head),
# so the cell stays correct if the backbone/head configuration changes.
for head in (model.classifier, model.aux_classifier):
    if head is None:  # the aux head only exists when aux_loss is enabled
        continue
    old_conv = head[4]
    head[4] = nn.Conv2d(old_conv.in_channels, NUM_CLASSES, kernel_size=(1, 1), stride=(1, 1))

# Fine-tune the whole network: the pretrained backbone is not frozen.
model.requires_grad_(True)

all_params = sum(p.numel() for p in model.parameters())
train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"No of parameters: {all_params}")
print(f"No of trainable parameters: {train_params}")

In [None]:
# Experiment tracking: authenticate with Weights & Biases and send the
# Lightning logs to the "Spacecrafts" project.
wandb.login()
wandb_logger = WandbLogger(project="Spacecrafts")

# Lightning wrapper around the segmentation network and both datasets.
# NOTE(review): the third positional argument is presumably the initial
# learning rate — confirm against SegModel in utils.seg_utils.
initial_lr = 0.02
pl_model = SegModel(model, BATCH_SIZE, initial_lr, train_dataset, valid_dataset)

# Checkpoint (weights only) on the best validation IoU.
checkpoint_callback = ModelCheckpoint(
    monitor="val_iou",
    mode="max",
    save_weights_only=True,
)

# Trainer: at most 500 epochs, stopped early once val_iou has not improved for
# 15 epochs; the learning rate is logged once per epoch.
trainer = pl.Trainer(
    max_epochs=500,
    logger=wandb_logger,
    log_every_n_steps=10,
    callbacks=[
        checkpoint_callback,
        LearningRateMonitor(logging_interval="epoch"),
        EarlyStopping(monitor="val_iou", patience=15, mode="max"),
    ],
)

# Batch size and initial learning rate estimation, training, and export.
# Everything that uses the W&B run sits inside try/finally so that a crash or a
# manual interrupt still closes the run; otherwise it stays attached to the
# kernel and a later WandbLogger may log into the stale run.
try:
    tuner = pl.tuner.Tuner(trainer)
    # Both searches update pl_model in place.
    # NOTE(review): this relies on SegModel exposing batch_size / lr attributes — confirm.
    tuned_batch_size = tuner.scale_batch_size(pl_model)
    lr_finder = tuner.lr_find(pl_model)
    print(f"Tuned batch size: {tuned_batch_size}")
    if lr_finder is not None:
        print(f"Suggested learning rate: {lr_finder.suggestion()}")

    # Model training
    trainer.fit(pl_model)

    # trainer.fit leaves pl_model at the weights of the *last* epoch. After early
    # stopping (or the epoch cap), that can be many epochs past the best one.
    # Restore the best val_iou checkpoint tracked by checkpoint_callback first.
    best_ckpt_path = checkpoint_callback.best_model_path
    if best_ckpt_path:
        # Trusted file written by this very run. Lightning checkpoints contain
        # more than tensors, hence weights_only=False.
        best_ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
        pl_model.load_state_dict(best_ckpt["state_dict"])

    # Saving the weights (create the output directory if it is missing).
    WORKING.mkdir(parents=True, exist_ok=True)
    torch.save(pl_model.state_dict(), WORKING / "spacecrafts.pt")
finally:
    wandb.finish()