In [1]:
import torch
from torch.utils.data import Dataset

from transformers import SegformerImageProcessor
from transformers.image_utils import ChannelDimension

In [42]:
class SegFormerDeforestationDataset(Dataset):
    """Dataset for forecasting deforestation with SegFormer.

    Features and labels are read from pre-processed tensor files. Each item is
    the output of ``feature_extractor`` for one sample with the leading batch
    dimension removed, so that a ``DataLoader`` can batch the items itself.
    """

    def __init__(self, dataset, feature_extractor):
        """
        Args:
            dataset: split name used to build the file names:
                "train", "val" or "test".
            feature_extractor: callable invoked as
                ``feature_extractor(features, labels, return_tensors="pt")``
                that returns a mapping of tensors with a leading batch axis.
        """
        self.features = torch.load(f'../data/processed/{dataset}_features.pt')
        self.labels = torch.load(f'../data/processed/{dataset}_labels.pt')
        self.feature_extractor = feature_extractor

    def __len__(self):
        """Return the number of samples (length of the labels tensor)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the encoded inputs for sample ``idx`` without a batch axis.

        Args:
            idx: index of the sample.

        Returns:
            Whatever mapping ``feature_extractor`` returns, with axis 0 of
            every tensor squeezed away.
        """
        # Keep the last 3 slices of the final axis (last 3 years) and shift
        # the values down by one. uint8 arithmetic wraps around, so a stored
        # 0 becomes 255.
        features = self.features[idx][:,:,-3:].to(torch.uint8) - 1
        # Predict the next year (index 0 of the final axis) and shift labels
        # down by one; a stored 0 wraps to 255.
        # NOTE(review): presumably 255 is meant as the "ignore" label - confirm.
        labels = self.labels[idx][:,:,0].squeeze().to(torch.uint8) - 1
        encoded_inputs = self.feature_extractor(features, labels, return_tensors="pt")

        # Remove only the batch dimension. A bare squeeze_() would also drop
        # any other size-1 axis (e.g. a single channel) and corrupt the shape.
        for tensor in encoded_inputs.values():
            tensor.squeeze_(0)

        return encoded_inputs

def get_data_loaders(batch_size=32):
    """Create the train, validation and test data loaders.

    All three datasets share a single ``SegformerImageProcessor`` instance.
    Only the training loader shuffles its samples.

    Args:
        batch_size: number of samples per batch for every loader
            (defaults to 32).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    feature_extractor = SegformerImageProcessor()

    loaders = []
    for split in ("train", "val", "test"):
        split_dataset = SegFormerDeforestationDataset(split, feature_extractor)
        loaders.append(
            torch.utils.data.DataLoader(
                split_dataset,
                batch_size=batch_size,
                shuffle=(split == "train"),  # shuffle the training data only
            )
        )

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

In [43]:
# Build the train, validation and test DataLoaders used by the cells below.
train_loader, val_loader, test_loader = get_data_loaders()

In [45]:
# Sanity check: pull one batch from the training loader and inspect it.
batch = next(iter(train_loader))

# Shapes of the whole batch.
for key in ("pixel_values", "labels"):
    print(batch[key].shape)

# Shapes of the first sample in the batch.
for key in ("pixel_values", "labels"):
    print(batch[key][0].shape)

# Contents of the first label map and the distinct values it holds.
sample_mask = batch["labels"][0]
print(sample_mask)
print(sample_mask.squeeze().unique())

torch.Size([32, 3, 512, 512])
torch.Size([32, 512, 512])


In [4]:
from transformers import SegformerForSemanticSegmentation

# Build the segmentation model from the "nvidia/mit-b0" checkpoint,
# configured for 2 output labels.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
)

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_fuse.weight', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear

In [5]:
import evaluate
# Load the "mean_iou" metric through the evaluate library; the training loop
# calls metric.compute(...) on the collected predictions and references.
metric = evaluate.load("mean_iou")

In [55]:
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# Training hyper-parameters, gathered in one place so they are easy to tune.
LEARNING_RATE = 0.00006
NUM_EPOCHS = 200
# Loss and metrics are printed every this many batches. Computing the metric
# walks over every prediction collected so far in the epoch, so doing it on
# each batch makes an epoch quadratically expensive.
LOG_EVERY_N_BATCHES = 100

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# move model to GPU when one is available, otherwise stay on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    print("Epoch:", epoch)

    # per-image predictions and ground truth collected during this epoch
    predictions = []
    references = []
    for idx, batch in enumerate(tqdm(train_loader)):
        # get the inputs
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        # evaluate
        with torch.no_grad():
            # resize the logits to the label resolution before taking the
            # per-pixel argmax over the class axis
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)

            # the metric expects predictions + labels as numpy arrays;
            # extend() adds one array per image in the batch
            predictions.extend(predicted.cpu().numpy())
            references.extend(labels.cpu().numpy())

        # print loss and metrics every LOG_EVERY_N_BATCHES batches; the
        # metrics cover all batches seen so far in the current epoch
        if idx % LOG_EVERY_N_BATCHES == 0:
            metrics = metric.compute(
                predictions=predictions,
                references=references,
                num_labels=2,
                ignore_index=255,
                reduce_labels=False,  # labels were already shifted by the dataset
            )

            print("Loss:", loss.item())
            print("Mean_iou:", metrics["mean_iou"])
            print("Mean accuracy:", metrics["mean_accuracy"])

Epoch: 0


  0%|          | 0/28 [00:00<?, ?it/s]

: 

: 