In [1]:
# import cv2
# import torch
# import torch.nn as nn
# import numpy as np

# # ----------------------------
# #  UNET Autoencoder Definition
# # ----------------------------
# import torch
# import torch.nn as nn
# import torchvision.models as models
# import torch.nn.functional as F

# class ResNet18AutoEncoder(nn.Module):
#     def __init__(self, pretrained=True):
#         super().__init__()

#         # -----------------
#         #  Encoder: ResNet-18
#         # -----------------
#         resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

#         # Extract layers up to layer4 (final conv output)
#         self.enc1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 64x128x128
#         self.pool = resnet.maxpool  # 64x64x64
#         self.enc2 = resnet.layer1  # 64x64x64
#         self.enc3 = resnet.layer2  # 128x32x32
#         self.enc4 = resnet.layer3  # 256x16x16
#         self.enc5 = resnet.layer4  # 512x8x8

#         # Bottleneck conv to 4x4
#         self.bottleneck = nn.Conv2d(512, 512, 3, stride=2, padding=1)  # 512x4x4

#         # -----------------
#         #  Decoder
#         # -----------------
#         self.up1 = nn.ConvTranspose2d(512, 512, 2, stride=2)  # 512x8x8
#         self.dec1 = nn.Conv2d(512 + 512, 512, 3, padding=1)  # skip from enc5

#         self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)  # 256x16x16
#         self.dec2 = nn.Conv2d(256 + 256, 256, 3, padding=1)  # skip from enc4

#         self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)  # 128x32x32
#         self.dec3 = nn.Conv2d(128, 128, 3, padding=1)

#         self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)  # 64x64x64
#         self.dec4 = nn.Conv2d(64, 64, 3, padding=1)

#         self.up5 = nn.ConvTranspose2d(64, 64, 2, stride=2)  # 64x128x128
#         self.dec5 = nn.Conv2d(64, 64, 3, padding=1)

#         self.up6 = nn.ConvTranspose2d(64, 3, 2, stride=2)  # 3x256x256

#     def forward(self, x):
#         # Encoder
#         e1 = self.enc1(x)     # 64x128x128
#         e2 = self.enc2(self.pool(e1))  # 64x64x64
#         e3 = self.enc3(e2)    # 128x32x32
#         e4 = self.enc4(e3)    # 256x16x16
#         e5 = self.enc5(e4)    # 512x8x8

#         # Bottleneck
#         b = self.bottleneck(e5)  # 512x4x4

#         # Decoder with skip connections from e5 and e4
#         d1 = F.relu(self.dec1(torch.cat([self.up1(b), torch.zeros_like(e5)], dim=1)))  # 512x8x8
#         d2 = F.relu(self.dec2(torch.cat([self.up2(d1), torch.zeros_like(e4)], dim=1))) # 256x16x16
#         d3 = F.relu(self.dec3(self.up3(d2)))  # 128x32x32
#         d4 = F.relu(self.dec4(self.up4(d3)))  # 64x64x64
#         d5 = F.relu(self.dec5(self.up5(d4)))  # 64x128x128
#         out = torch.sigmoid(self.up6(d5))     # 3x256x256

#         return out


# def live_train_unet():
#     size = 256

#     model = ResNet18AutoEncoder().cuda()
#     model = torch.load("model_1.pth", weights_only=False)
#     opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#     loss_fn = nn.MSELoss()

#     cap = cv2.VideoCapture(0)

#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break

#         # preprocess webcam frame
#         inp = cv2.resize(frame, (size, size))
#         inp_tensor = torch.tensor(inp).permute(2,0,1).float().cuda()/255.0
#         inp_tensor = inp_tensor.unsqueeze(0)

#         # ------- TRAIN ON THIS FRAME -------
#         model.train()
#         out = model(inp_tensor)
#         loss = loss_fn(out, inp_tensor)

#         opt.zero_grad()
#         loss.backward()
#         opt.step()
#         # -----------------------------------

#         # show reconstruction
#         model.eval()
#         with torch.no_grad():
#             decoded = model(inp_tensor)[0].cpu().permute(1,2,0).numpy()

#         decoded = (decoded * 255).astype(np.uint8)
#         combined = np.hstack((inp, decoded))

#         cv2.putText(combined, f"Loss: {loss.item():.4f}", (10,20),
#                     cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 1)

#         cv2.imshow("Training (Left) | Reconstructed (Right)", combined)

#         if cv2.waitKey(1) & 0xFF == 27:
#             torch.save(model, "model_1.pth")
#             break  # ESC ends

#     cap.release()
#     cv2.destroyAllWindows()


# # Run it
# # live_train_unet()


In [2]:
# live_train_unet()

In [63]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.functional as F

class ResNet18AutoEncoder(nn.Module):
    """Convolutional autoencoder built on a torchvision ResNet encoder.

    Despite the class name, the encoder backbone is ResNet-34; the name is
    kept so existing callers and pickled checkpoints keep working.

    The encoder output is squeezed through a stride-2 bottleneck convolution
    and decoded by six transposed-convolution upsampling steps. Each of the
    first five decoder stages has a channel slot sized for a skip connection,
    but that slot is filled with zeros: no encoder activations reach the
    decoder, so the reconstruction depends on the bottleneck tensor alone.

    The shape comments below assume a 256x256 input. Other sizes should work
    as long as height and width are multiples of 64, so that every upsampled
    map matches the corresponding encoder map.

    Args:
        pretrained: If True (default), initialise the encoder with the
            ImageNet weights; if False, use random initialisation.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # -----------------
        # ENCODER (ResNet-34 backbone)
        # -----------------
        weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet34(weights=weights)

        self.enc1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 64 x 128
        self.pool = resnet.maxpool                                       # 64 x 64
        self.enc2 = resnet.layer1                                        # 64 x 64
        self.enc3 = resnet.layer2                                        # 128 x 32
        self.enc4 = resnet.layer3                                        # 256 x 16
        self.enc5 = resnet.layer4                                        # 512 x 8

        # -----------------
        # Bottleneck -> 4x4
        # -----------------
        self.bottleneck = nn.Conv2d(512, 512, 3, stride=2, padding=1)

        # -----------------
        # SKIP PROJECTIONS
        # -----------------
        # NOTE: these 1x1 convolutions are not used by forward(). They are
        # kept so the module's parameter set stays compatible with models
        # that were saved with them.
        self.proj5 = nn.Conv2d(512, 512, 1)
        self.proj4 = nn.Conv2d(256, 256, 1)
        self.proj3 = nn.Conv2d(128, 128, 1)
        self.proj2 = nn.Conv2d(64, 64, 1)
        self.proj1 = nn.Conv2d(64, 64, 1)

        # -----------------
        # DECODER
        # -----------------
        self.up1 = nn.ConvTranspose2d(512, 512, 2, stride=2)          # 4 -> 8
        self.dec1 = nn.Conv2d(512 + 512, 512, 3, padding=1)

        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)          # 8 -> 16
        self.dec2 = nn.Conv2d(256 + 256, 256, 3, padding=1)

        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)          # 16 -> 32
        self.dec3 = nn.Conv2d(128 + 128, 128, 3, padding=1)

        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)           # 32 -> 64
        self.dec4 = nn.Conv2d(64 + 64, 64, 3, padding=1)

        self.up5 = nn.ConvTranspose2d(64, 64, 2, stride=2)            # 64 -> 128
        self.dec5 = nn.Conv2d(64 + 64, 64, 3, padding=1)

        self.up6 = nn.ConvTranspose2d(64, 3, 2, stride=2)             # 128 -> 256

    @staticmethod
    def _decode_stage(up, dec, x, skip_ref):
        """Upsample ``x``, append zeros shaped like ``skip_ref``, then conv + ReLU."""
        x = up(x)
        # The skip slot is deliberately zero-filled: only the shape, dtype and
        # device of the encoder map are used, never its values.
        x = torch.cat([x, torch.zeros_like(skip_ref)], dim=1)
        # nn.functional is used directly rather than a module-level ``F``
        # alias: in this notebook ``F`` is first bound to ``torch.functional``,
        # which has no ``relu``.
        return nn.functional.relu(dec(x))

    def forward(self, x):
        """Encode ``x`` and return the reconstruction, squashed to [0, 1] by a sigmoid."""
        # --------- ENCODER ----------
        e1 = self.enc1(x)             # 64 x 128
        e2 = self.enc2(self.pool(e1)) # 64 x 64
        e3 = self.enc3(e2)            # 128 x 32
        e4 = self.enc4(e3)            # 256 x 16
        e5 = self.enc5(e4)            # 512 x 8

        # --------- BOTTLENECK ----------
        b = self.bottleneck(e5)       # 512 x 4

        # --------- DECODER ----------
        d1 = self._decode_stage(self.up1, self.dec1, b, e5)
        d2 = self._decode_stage(self.up2, self.dec2, d1, e4)
        d3 = self._decode_stage(self.up3, self.dec3, d2, e3)
        d4 = self._decode_stage(self.up4, self.dec4, d3, e2)
        d5 = self._decode_stage(self.up5, self.dec5, d4, e1)

        return torch.sigmoid(self.up6(d5))


extra-low -> 8 KB / frame
low -> (TODO: value not recorded yet)

In [68]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import torchvision.models as models
import torch.nn.functional as F



def train_on_video(video_path="movie3.mp4", model_path="model_2.pth", size=256,
                   save_frames=(1, 1300, 1400, 1500, 1600, 1700), tag="t", lr=1e-4):
    """Train the autoencoder online on a looping video while previewing results.

    Every frame is resized to ``size`` x ``size``, used for one Adam step on
    the MSE reconstruction loss, and then shown stacked above the model's
    reconstruction in an OpenCV window. When the video ends it is rewound and
    training continues. Pressing ESC in the preview window saves the model to
    ``model_path`` and stops.

    Requires a CUDA device and the ``ResNet18AutoEncoder`` class defined
    earlier in this notebook.

    Args:
        video_path: Video file to train on.
        model_path: Checkpoint to resume from (if present) and to save to.
        size: Side length, in pixels, of the square frames fed to the model.
        save_frames: Loop-iteration numbers at which the reconstruction is
            written to ``frame_<n>_<tag>.jpg``. The counter advances on every
            loop iteration, including the one that hits end-of-video.
        tag: Suffix used in the saved frame file names.
        lr: Adam learning rate.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that you created yourself.
    try:
        model = torch.load(model_path, weights_only=False)
        print("Loaded saved model.")
    except FileNotFoundError:
        print("No saved model found, starting fresh.")
        model = ResNet18AutoEncoder()
    except Exception as exc:
        # Still best-effort, but report the real reason instead of pretending
        # the checkpoint file was simply missing.
        print(f"Could not load {model_path!r} ({exc!r}); starting fresh.")
        model = ResNet18AutoEncoder()
    model = model.cuda()

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # VIDEO INPUT INSTEAD OF WEBCAM
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Cannot open {video_path}")
        cap.release()
        return

    save_frames = set(save_frames)
    # Black separator drawn between the input and its reconstruction.
    bar = np.zeros((5, size, 3), dtype=np.uint8)

    # The model is never switched to eval mode, so the preview forward pass
    # below also runs in train mode, exactly as before.
    # NOTE(review): confirm this is intended; the earlier webcam version of
    # this loop called model.eval() before producing the preview.
    model.train()

    frame_itr = 0
    just_rewound = False
    try:
        while True:
            print(frame_itr)
            frame_itr += 1

            ret, frame = cap.read()
            if not ret:
                if just_rewound:
                    # Rewinding did not produce a frame either; stop instead
                    # of spinning forever on an unreadable video.
                    print(f"Error: no readable frames in {video_path}")
                    break
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                just_rewound = True
                continue
            just_rewound = False

            # Preprocess: HxWx3 uint8 (OpenCV BGR order) -> 1x3xHxW float in [0, 1].
            inp = cv2.resize(frame, (size, size))
            inp_tensor = torch.tensor(inp).permute(2, 0, 1).float().cuda() / 255.0
            inp_tensor = inp_tensor.unsqueeze(0)

            # ----- TRAIN -----
            out = model(inp_tensor)
            loss = loss_fn(out, inp_tensor)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Reconstruction after the update, for preview and saving.
            with torch.no_grad():
                decoded = model(inp_tensor)[0].cpu().permute(1, 2, 0).numpy()

            # The model is trained to reproduce the BGR input, so its output
            # is already in BGR order and needs no colour conversion before
            # cv2.imshow / cv2.imwrite.
            decoded_uint8 = np.clip(decoded * 255, 0, 255).astype(np.uint8)

            # Stack vertically: input on top, separator, then reconstruction.
            combined = np.vstack((inp, bar, decoded_uint8))
            cv2.imshow("Vertical Stack: Input over Decoded", combined)

            if frame_itr in save_frames:
                cv2.imwrite(f"frame_{frame_itr}_{tag}.jpg", decoded_uint8)

            # Mask to the low byte: some platforms set higher bits in the
            # value returned by waitKey.
            if (cv2.waitKey(1) & 0xFF) == 27:  # ESC
                print("Saving model...")
                torch.save(model, model_path)
                break
    finally:
        # Release the capture and close the window even if training raises.
        cap.release()
        cv2.destroyAllWindows()


# Run it
# train_on_video()


In [69]:
# Start training; the loop runs until ESC is pressed in the preview window,
# which also saves the model before exiting.
train_on_video()

Loaded saved model.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
Saving model...
