Basic tutorials for core PyTorch operations: cross-entropy loss, softmax, upsampling, cropping, and concatenation.

Rough code to understand cross-entropy loss in torch. Btw, nn.CrossEntropyLoss applies log-softmax by itself, so you can pass the "raw" (unnormalized) output directly. The log used is the natural log. Log of base 2 could also be used; changing the base only rescales the loss by a constant factor, so it shouldn't really matter: https://stats.stackexchange.com/questions/295174/difference-in-log-base-for-cross-entropy-calcuation

In [None]:
import torch
from torch import nn

#predicted output (raw scores, one channel per class)
x = torch.tensor([[[0.41, 0.56], [0.69, 0.84]], [[0.57, 0.37], [0.29, 0.90]]])
x = x.unsqueeze(0) #Image tensors are conventionally stored with a batch dimension too: [batch, channels, height, width].

#You can either give the actual class probabilities for every pixel.
y = torch.tensor([[[0., 1.], [1., 1.]],[[1., 0.], [0., 0.]]]) #target output
y = y.unsqueeze(0) #Same reason as above.
print(x.shape, y.shape)

#Or you can specify which class each pixel belongs to. The first class has
#index 0, the second class has index 1, and so on. This is better computationally.
#If every pixel belongs to exactly one class with 100% certainty, i.e. your
#probability targets would only contain 0s and 1s, then it is better to use this.
y_class_index = torch.tensor([[1, 0], [0, 0]]) #target output
y_class_index = y_class_index.unsqueeze(0) #interpreted as [batch, height, width]

#reduction 'none' gives you the cross entropy value of each pixel.
entropy_1 = nn.CrossEntropyLoss(reduction='none')

#reduction 'sum' gives you the cross entropy value of the whole image.
entropy_2 = nn.CrossEntropyLoss(reduction='sum')

#The outputs are the same, but the "class_index" approach is computationally
#better. Look at the "NOTE" portion here:
#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
#It makes sense that it is computationally better too. If the target probability
#of a class is 0, then the log-probability term of that class is multiplied by 0
#and drops out of the cross-entropy sum (look at the formula for cross-entropy.)
#Therefore, you can ignore those terms. With the class index approach, you are
#essentially telling pytorch which single term to keep for each pixel.
print(entropy_1(x, y))
print(entropy_1(x, y_class_index))

print(entropy_2(x, y))
print(entropy_2(x, y_class_index))

#Interesting that stuff in pytorch is so nicely structured to be straightforward
#for images. But maybe that's a bias because I'm only doing stuff for images.
#Maybe stuff is straightforward for other stuff too.

torch.Size([1, 2, 2, 2]) torch.Size([1, 2, 2, 2])
tensor([[[0.6163, 0.6027],
         [0.5130, 0.7236]]])
tensor([[[0.6163, 0.6027],
         [0.5130, 0.7236]]])
tensor(2.4556)
tensor(2.4556)
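
To double-check the claim that the loss does log-softmax (with natural log) by itself, here is a minimal sketch that recomputes the per-pixel values by hand. The names `logits`, `target`, `manual` and `builtin` are just illustrative; the numbers are the same ones used in the cell above.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Same raw scores and class-index target as in the cell above.
logits = torch.tensor([[[0.41, 0.56], [0.69, 0.84]],
                       [[0.57, 0.37], [0.29, 0.90]]]).unsqueeze(0)  # [1, 2, 2, 2]
target = torch.tensor([[1, 0], [0, 0]]).unsqueeze(0)               # [1, 2, 2]

# Log-softmax over the channel (class) dimension, using the natural log.
log_probs = F.log_softmax(logits, dim=1)

# Pick the log-probability of the target class at every pixel and negate it.
manual = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)

builtin = nn.CrossEntropyLoss(reduction='none')(logits, target)
print(manual)
print(builtin)
```

Both prints should show the same per-pixel values as the 'none' output above, and summing them should give the 'sum' output.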


Rough code to understand how softmax works in torch.

In [None]:
x = torch.rand(2, 2, 2, 2)

#Softmax occurs across the batch dimension. For instance, the
#softmax for the value at the 1st pixel position in the 1st channel
#of the 1st sample would be between the values at the 1st pixel position
#in the 1st channel of all the 'N' samples in the batch.
soft_1 = nn.Softmax(dim=0)

#Softmax occurs between channels. For instance, the
#softmax for the value at the 1st pixel position in the 1st channel
#of some sample would be between the values at the 1st pixel position
#in all the channels of this sample.
soft_2 = nn.Softmax(dim=1)

#Softmax occurs amongst the values along the height of each channel.
soft_3 = nn.Softmax(dim=2)

#Softmax occurs amongst the values along the width of each channel.
soft_4 = nn.Softmax(dim=3)

print(soft_1(x))
print(soft_2(x))
print(soft_3(x))
print(soft_4(x))

#Btw, the values that are softmaxed together (i.e. the values along the chosen dim) sum up to 1 after softmax.

tensor([[[[0.4452, 0.3417],
          [0.5520, 0.3510]],

         [[0.4771, 0.5547],
          [0.5214, 0.5568]]],


        [[[0.5548, 0.6583],
          [0.4480, 0.6490]],

         [[0.5229, 0.4453],
          [0.4786, 0.4432]]]])
tensor([[[[0.4998, 0.3139],
          [0.5768, 0.4049]],

         [[0.5002, 0.6861],
          [0.4232, 0.5951]]],


        [[[0.5319, 0.5233],
          [0.5465, 0.6124]],

         [[0.4681, 0.4767],
          [0.4535, 0.3876]]]])
tensor([[[[0.4083, 0.4500],
          [0.5917, 0.5500]],

         [[0.4848, 0.5488],
          [0.5152, 0.4512]]],


        [[[0.5144, 0.4601],
          [0.4856, 0.5399]],

         [[0.5291, 0.5509],
          [0.4709, 0.4491]]]])
tensor([[[[0.4966, 0.5034],
          [0.5391, 0.4609]],

         [[0.3112, 0.6888],
          [0.3687, 0.6313]]],


        [[[0.3896, 0.6104],
          [0.3393, 0.6607]],

         [[0.3814, 0.6186],
          [0.4023, 0.5977]]]])
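
A quick way to convince yourself which values get softmaxed together is to sum the output along the same dim and check that you get all ones. This is only a small sketch; the seed is an arbitrary choice and `manual` is an illustrative name for the by-hand channel softmax.

```python
import torch
from torch import nn

torch.manual_seed(0)  # arbitrary seed so the run is repeatable
x = torch.rand(2, 2, 2, 2)  # [batch, channels, height, width]

# For every dim, the softmax output sums to 1 along that same dim.
for dim in range(x.dim()):
    out = nn.Softmax(dim=dim)(x)
    sums = out.sum(dim=dim)
    print(dim, torch.allclose(sums, torch.ones_like(sums)))

# Softmax over channels written out by hand: exp(x) / sum of exp(x) over dim 1.
manual = x.exp() / x.exp().sum(dim=1, keepdim=True)
print(torch.allclose(manual, nn.Softmax(dim=1)(x)))
```

For per-pixel class probabilities in segmentation, dim=1 (channels) is the one you normally want.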


Rough code to understand how the upsample class works in torch.

In [None]:
import requests
from PIL import Image
from torchvision import transforms

img = Image.open(requests.get("https://c8.alamy.com/zooms/9/7df650603bfe4193bab024ee29da5461/2btr9xg.jpg", stream=True).raw)
tensor_trans = transforms.ToTensor() #converting the image to a tensor
img = tensor_trans(img) #converting the image to a tensor continued
img = img.unsqueeze(0) #adding a batch dimension to the image.
#Doing this because of how the Upsample class interprets the shape values of
#tensors. If our image's shape was (x, y, z), Upsample would interpret 'x' as batch
#size, 'y' as number of channels and 'z' as the width. But for our image 'x' is the
#number of channels, 'y' is the height and 'z' is the width. If the shape is
#(w, x, y, z), Upsample interprets 'w' as batch size, and 'x', 'y' and 'z' as wanted
#above. Adding a batch dimension is fine; it doesn't lose generality. If you only
#want to upsample a single image, the batch size is just going to be 1. Btw, in
#the (x, y, z) case, Upsample only scales the last dimension (the width), so both
#"theoretically" and practically a 3d tensor doesn't work for our purpose. For more
#clarification and information: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html

print(img.shape)
#by default, nearest is used. scale factor defines by what factor we want to increase
#the height/width. new_height_width = old_height_width * scale_factor.
upscale = nn.Upsample(scale_factor=2, mode='nearest')
img = upscale(img) #upsampling
print(img.shape)

torch.Size([1, 3, 447, 640])
torch.Size([1, 3, 894, 1280])
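
To see what 'nearest' actually does to the values (and not just the shape), here is a small sketch on a tiny hand-made tensor instead of a downloaded image. The tensors `small` and `three_d` are made up for illustration.

```python
import torch
from torch import nn

upscale = nn.Upsample(scale_factor=2, mode='nearest')

# A 2x2 single-channel "image" with a batch dimension: [1, 1, 2, 2].
small = torch.tensor([[1., 2.], [3., 4.]]).reshape(1, 1, 2, 2)
big = upscale(small)
print(big.shape)   # torch.Size([1, 1, 4, 4])
print(big[0, 0])   # every pixel is repeated into a 2x2 block

# Without the batch dimension, a 3d tensor is read as [batch, channels, width],
# so only the last dimension gets scaled.
three_d = torch.rand(3, 4, 5)
print(upscale(three_d).shape)  # torch.Size([3, 4, 10])
```

The second part shows why the unsqueeze(0) matters: with a (channels, height, width) tensor the height would be left untouched.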


Rough code to understand cropping in torch.

In [None]:
x = torch.rand(1, 512, 64, 64)
#Look at the dotted lines in Fig. 1 of the u-net paper. I think that implies Center Crop.
crop = transforms.CenterCrop(56)
x = crop(x)
print(x.shape)


torch.Size([1, 512, 56, 56])
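
If it helps to see what a center crop boils down to, here is a rough sketch that does the same thing with plain slicing. `center_crop_by_slicing` is a hypothetical helper written for this note, not a torchvision function; it should pick the same window as CenterCrop when the size difference is even, as it is here (64 - 56 = 8).

```python
import torch

def center_crop_by_slicing(t, size):
    """Crop the last two dims (height, width) of t to size x size around the center."""
    h, w = t.shape[-2:]
    top = (h - size) // 2    # rows to skip at the top
    left = (w - size) // 2   # columns to skip on the left
    return t[..., top:top + size, left:left + size]

x = torch.rand(1, 512, 64, 64)
cropped = center_crop_by_slicing(x, 56)
print(cropped.shape)  # torch.Size([1, 512, 56, 56])
```

Here 4 rows/columns are dropped on every side; the batch and channel dimensions are left alone.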


Rough code to understand how to do concat as required by the u-net paper in torch.

In [None]:
x = torch.ones(1, 2, 4, 4)
y = torch.zeros(1, 2, 4, 4)

#The dim argument in torch.cat refers to a position in the shape above. Remember, shape values are
#(batch_size, num_of_channels, height, width).

#This does the concat on 'batch level' i.e. the new tensor will contain two image tensors x and y.
print(torch.cat((x, y), 0))
print(torch.cat((x, y), 0).shape)

#This does the concat on 'channel level' i.e. the channels of y will be added to x
#to create a new image tensor with more channels, specifically, new_number_of_
#channels = no_of_channels_x + no_of_channels_y
print(torch.cat((x, y), 1))
print(torch.cat((x, y), 1).shape)

#Concat on 'height level' i.e. the rows of the 1st channel of y will be stacked below the rows
#of the 1st channel of x (vertically). This is done similarly for each channel.
print(torch.cat((x,y), 2))
print(torch.cat((x, y), 2).shape)

#Concat on 'width level' i.e. each row of the 1st channel of y will be appended to the end of the
#matching row of the 1st channel of x (horizontally). This is done similarly for each channel.
print(torch.cat((x, y), 3))
print(torch.cat((x, y), 3).shape)

#Look at the outputs to understand more. It also really helps if you understand
#how pytorch tensors are structured i.e. what each pair of [] represents -- specifically
#in the context of images I guess in our case.



tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]],


        [[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
torch.Size([2, 2, 4, 4])
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
torch.Size([1, 4, 4, 4])
tensor([[[[1., 1., 1., 1.],
      