In [34]:
import torch
import torch.nn as nn

class UNet1D(nn.Module):
    def __init__(self, in_channels, out_channels, depth=2, num_layers=2):
        super(UNet1D, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.depth = depth
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.num_start_filters = 32

        self._create_unet(self.in_channels, self.num_start_filters)
        self.bottleneck = nn.Sequential(
            nn.Conv1d(self.num_start_filters * 2 ** (self.depth - 1), 2 * self.num_start_filters * 2 ** (self.depth - 1), kernel_size=1, padding=0),
            nn.ReLU()
        )
        self.logits = nn.Conv1d(self.num_start_filters, self.out_channels, 1, 1)


    def _create_encoder_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def _create_decoder_block(self, in_channels, out_channels):
        return nn.ModuleList([nn.ConvTranspose1d(in_channels, in_channels//2, kernel_size=2, stride=2),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()])

    def _create_unet(self, in_channels, out_channels):
        for _ in range(self.depth):
            self.encoder.append(self._create_encoder_block(in_channels, out_channels))
            in_channels, out_channels = out_channels, out_channels*2

        out_channels = in_channels
        in_channels = in_channels * 2
        for _ in range(self.depth):
            self.decoder.append(self._create_decoder_block(in_channels, out_channels))
            in_channels, out_channels = out_channels, out_channels//2

    def forward(self, x):
        encoded = []
        for enc in self.encoder:
            x = enc(x)
            encoded.append(x)
            x = nn.MaxPool1d(kernel_size=2, stride=2)(x)

        x = self.bottleneck(x)  # Bottleneck layer

        for dec in self.decoder:
            ## Adding input with encoder concatenation
            enc_output = encoded.pop()
            x = dec[0](x)
            x = torch.cat((enc_output, x), dim=1)
            x = dec[1](x)
            x = dec[2](x)
        return self.logits(x)

input_channels = 4  
output_channels = 1 
depth = 2
num_layers = 2

model = UNet1D(input_channels, output_channels, depth, num_layers)
print(model)

UNet1D(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(4, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): ReLU()
    )
    (1): Sequential(
      (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): ReLU()
    )
  )
  (decoder): ModuleList(
    (0): ModuleList(
      (0): ConvTranspose1d(128, 64, kernel_size=(2,), stride=(2,))
      (1): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (2): ReLU()
    )
    (1): ModuleList(
      (0): ConvTranspose1d(64, 32, kernel_size=(2,), stride=(2,))
      (1): Conv1d(64, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (2): ReLU()
    )
  )
  (bottleneck): Sequential(
    (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    (1): ReLU()
  )
  (logits): Conv1d(32, 1, kernel_size=(1,), stride=(1,))
)


In [35]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

In [36]:
print(count_parameters(model))

66369


In [37]:
# Generate synthetic data
import numpy as np
num_superpixels = 3000
num_features = 4
synthetic_data = np.random.rand(num_superpixels, num_features)
synthetic_data = torch.tensor(synthetic_data, dtype=torch.float32)

#Reshape
synthetic_data = synthetic_data.unsqueeze(0).transpose(1, 2)

# Pass the synthetic data through the U-Net model
with torch.no_grad():
    output = model(synthetic_data)

print("Input shape:", synthetic_data.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([1, 4, 3000])
Output shape: torch.Size([1, 1, 3000])


In [None]:
class CloudSegmentationModel(nn.Module):
    def __init__(self):
        super(CloudSegmentationModel, self).__init__()
        self.unet = UNet1D(in_channels=4, out_channels=1)
        
    def forward(self, x):
        return torch.sigmoid(self.unet(x))

In [8]:
model = CloudSegmentationModel()

with torch.no_grad():
    output = model(synthetic_data)

print("Input shape:", synthetic_data.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([1, 4, 3000])
Output shape: torch.Size([1, 1, 3000])


In [38]:
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        
        return Dice_BCE

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
segmentationModel = CloudSegmentationModel().to(device)

train_loader = None # Train loader for our dataset
test_loader = None # Test loader for our dataset

# loss function and optimizer
criterion = DiceBCELoss()
optimizer = torch.optim.Adam(segmentationModel.parameters(), lr=0.001)

# Training loop
num_epochs = 10
segmentationModel.train() 

for epoch in range(num_epochs):
    running_loss = 0

    for superpixel, label in train_loader:
        superpixel = superpixel.to(device)
        label = label.to(device)

        # Forward pass
        output = segmentationModel(superpixel)
        loss = criterion(output, label)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print epoch statistics
    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Evaluation
    segmentationModel.eval()
    test_loss = 0

    with torch.no_grad():
        for superpixel, label in test_loader:
            superpixel = superpixel.to(device)
            label = label.to(device)

            output = segmentationModel(superpixel)
            test_loss += criterion(output, label).item()

    test_loss /= len(test_loader)
    print(f'Test Loss: {test_loss:.4f}')


TypeError: 'NoneType' object is not iterable