In [28]:
import os
import os.path as osp
import sys
sys.path.append("..")

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import core
import utils

In [29]:
# Set the TORCH_HOME environment variable to the current working directory
# for this process (presumably so torch.hub caches under "." — TODO confirm).
os.environ["TORCH_HOME"] = "."

In [30]:
# Load the 'deeplabv3_resnet50' entry point from the local hub directory
# 'hub/deeplabv3' (source="local"), forwarding pretrained=True to it.
model = torch.hub.load('hub/deeplabv3', source="local", model='deeplabv3_resnet50', pretrained=True)

In [31]:
# Print the whole module tree first, then the auxiliary classifier head on
# its own so its layers can be inspected separately.
print(model)
print(model.aux_classifier)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

`output['out']` contains the semantic masks, and `output['aux']` contains the per-pixel auxiliary loss values. In inference mode, `output['aux']` is not useful. With the pretrained head, `output['out']` has shape `(N, 21, H, W)`.

In [32]:
# Replace the last layer of both heads with a 1x1 convolution that maps
# 256 channels to 9 output channels (auxiliary head first, then main head).
model.aux_classifier[-1] = nn.Conv2d(in_channels=256, out_channels=9, kernel_size=1, stride=1)
model.classifier[-1] = nn.Conv2d(in_channels=256, out_channels=9, kernel_size=1, stride=1)

# Input adapters: `stretch` bilinearly upsamples height by 4x and leaves
# width unchanged; `squeeze` is a padded 3x3 convolution from 5 to 3 channels.
stretch = nn.Upsample(scale_factor=(4, 1), mode="bilinear", align_corners=False)
squeeze = nn.Conv2d(in_channels=5, out_channels=3, kernel_size=3, stride=1, padding=1)

In [33]:
import torch
import torch.nn as nn

# Smoke test: push a random batch through the adapters and the patched
# hub model, then compute a cross-entropy loss on both heads.

# Create a random 64x512 feature map batch.
x = torch.randn(4, 5, 64, 512)  # (batch_size, channels, height, width)

# `stretch` scales the height by 4 (scale_factor=(4, 1)); `squeeze` maps the
# 5 input channels to 3.
x = stretch(x)
x = squeeze(x)

model.train()
x = model(x)
out, aux = x["out"], x["aux"]

# Resize both heads back to the input resolution. F.interpolate replaces the
# deprecated F.upsample; the default mode ("nearest") is the same for both.
out = F.interpolate(out, size=(64, 512))
aux = F.interpolate(aux, size=(64, 512))

# Random integer labels in [0, 9), one per pixel.
y = torch.randint(0, 9, (4, 64, 512))
criterion = nn.CrossEntropyLoss()
loss1 = criterion(out, y)
loss2 = criterion(aux, y)



In [34]:
class DeepLabV3(nn.Module):
    """DeepLabV3-ResNet50 wrapped with input adapters for multi-channel maps.

    The input is bilinearly resized to ``stretch_shape``, projected from
    ``in_channels`` to 3 channels with a 1x1 convolution, and passed to a
    ``deeplabv3_resnet50`` model loaded from the local hub directory
    ``hub/deeplabv3``. The last layer of both the main and the auxiliary
    classifier is replaced so that each head emits ``num_cls`` channels.

    Args:
        num_cls: Number of output channels (classes) for both heads.
        stretch_shape: ``(height, width)`` the input is resized to before it
            enters the network.
        in_channels: Number of channels in the input tensor.
    """

    def __init__(self, num_cls: int, stretch_shape: tuple, in_channels: int):
        super().__init__()

        self.stretch = nn.Upsample(size=stretch_shape, mode='bilinear', align_corners=True)
        self.squeeze = nn.Conv2d(in_channels, 3, kernel_size=1, stride=1, padding=0)

        self.deeplab = torch.hub.load('hub/deeplabv3', source="local", model='deeplabv3_resnet50', pretrained=True)
        self.deeplab.classifier[-1] = nn.Conv2d(256, num_cls, kernel_size=(1, 1), stride=(1, 1))
        self.deeplab.aux_classifier[-1] = nn.Conv2d(256, num_cls, kernel_size=(1, 1), stride=(1, 1))

    @staticmethod
    def _resize(logits: torch.Tensor, size) -> torch.Tensor:
        """Bilinearly resize ``logits`` to the spatial ``size``."""
        # F.interpolate is the supported replacement for the deprecated
        # F.upsample and takes the same arguments.
        return F.interpolate(logits, size=size, mode="bilinear", align_corners=True)

    def forward(self, x: torch.Tensor):
        """Run the network and return ``(out, aux)`` at the input resolution.

        Args:
            x: Input tensor; its trailing two dimensions are treated as the
                spatial size that the outputs are resized back to.

        Returns:
            A tuple ``(out, aux)`` holding the main-head and auxiliary-head
            outputs, each resized to the spatial size of ``x``.
        """
        original_shape = x.shape[2:]
        x = self.stretch(x)
        x = self.squeeze(x)
        x = self.deeplab(x)
        return (
            self._resize(x["out"], original_shape),
            self._resize(x["aux"], original_shape),
        )

In [35]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training split of the project's KITTISpherical dataset, configured from
# ../conf/data.yaml and served in shuffled batches of 4 by 2 workers.
data_conf = core.readconfyaml.read("../conf/data.yaml")
train_set = core.dataset.KITTISpherical("../data", "train", data_conf)
train_dataloader = torch.utils.data.DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# 9-class model for 5-channel inputs, resized to 256x512 internally.
model = DeepLabV3(num_cls=9, stretch_shape=(256, 512), in_channels=5)
model = model.to(device)

# Adam with default hyper-parameters; the learning rate is multiplied by 0.1
# after every 30 scheduler steps.
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# One loss weight per class (9 entries); targets equal to -1 are ignored.
class_weights = torch.tensor(
    [0.125, 1.2, 1.1, 0.5, 1.5, 1, 1, 1.2, 1.2],
    dtype=torch.float32,
)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)
criterion = criterion.to(device)

In [36]:
# Train for 50 epochs. Every batch is optimised on the main-head loss plus
# half of the auxiliary-head loss; the scheduler advances once per epoch.
for epoch in range(50):
    for batch_idx, (features, labels) in enumerate(train_dataloader):
        features = features.to(device).float()
        labels = labels.to(device).long()

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        main_logits, aux_logits = model(features)
        loss = criterion(main_logits, labels) + 0.5 * criterion(aux_logits, labels)

        loss.backward()
        optimizer.step()

        print(loss.item())

    scheduler.step()


3.4000120162963867
3.1682510375976562
3.2154836654663086
2.8017578125
2.423081159591675
2.658344030380249
2.3561043739318848
1.947843313217163
2.5080337524414062
2.198582887649536
1.610268235206604
1.7348902225494385
1.4596961736679077
1.6063477993011475
1.7047138214111328
1.2857706546783447
1.229272484779358
2.444493532180786
2.2974212169647217
1.560075044631958
1.1718169450759888
1.6614658832550049
1.164381742477417
1.9247126579284668
1.771458625793457
1.1350548267364502
1.1128251552581787
0.8781954050064087
0.7269815802574158
1.0978459119796753
1.7143548727035522
1.290608286857605
0.9068481922149658
0.9184226989746094
1.078495740890503
1.1456791162490845
0.9442987442016602
2.0510640144348145
0.8484172821044922
1.281785011291504
1.5556786060333252
0.9312343597412109
1.6731020212173462
0.7015662789344788
0.8422929644584656
1.0853056907653809
0.6909208297729492
0.9691123962402344
1.0129384994506836
0.5143333673477173
0.8398679494857788
0.8047423362731934
1.535339593887329
0.69990158081

KeyboardInterrupt: 