Skip to content

Commit

Permalink
add assert statements to ensure grad enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
lillythomas committed Jan 11, 2024
1 parent bdf09cc commit 25ad052
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions notebooks/231220-custom-mae-embeddings-finetune.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1133,12 +1133,16 @@
"\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" torch.set_grad_enabled(True)\n",
" assert torch.is_grad_enabled()\n",
" assert all(p.requires_grad for p in self.parameters())\n",
" print(\"grad enabled\")\n",
" print(batch[\"pixels\"].shape, batch[\"labels\"].shape)\n",
" x, y = batch[\"pixels\"], batch[\"labels\"]\n",
" prediction = self.model(x)\n",
" y = y.squeeze()\n",
" y = torch.tensor(y.to(dtype=torch.long), requires_grad=False)\n",
" prediction = torch.tensor(prediction.to(dtype=torch.float32), requires_grad=True)\n",
" prediction = torch.tensor(prediction.float(), requires_grad=True)\n",
"        targets_one_hot = torch.nn.functional.one_hot(y, 2) # outputs in b x 1 x h x w x c\n",
" targets_one_hot = targets_one_hot.squeeze()\n",
" targets_one_hot = targets_one_hot.permute(0,3,1,2) # ouputs in b x c x h x w\n",
Expand All @@ -1150,11 +1154,15 @@
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" torch.set_grad_enabled(True)\n",
" assert torch.is_grad_enabled()\n",
" assert all(p.requires_grad for p in self.parameters())\n",
" print(\"grad enabled\")\n",
" x, y = batch[\"pixels\"], batch[\"labels\"]\n",
" prediction = self.model(x)\n",
" y = y.squeeze()\n",
" y = torch.tensor(y.to(dtype=torch.long), requires_grad=False)\n",
" prediction = torch.tensor(prediction.to(dtype=torch.float32), requires_grad=True)\n",
" prediction = torch.tensor(prediction.float(), requires_grad=True)\n",
"        targets_one_hot = torch.nn.functional.one_hot(y, 2) # outputs in b x 1 x h x w x c\n",
" targets_one_hot = targets_one_hot.squeeze()\n",
" targets_one_hot = targets_one_hot.permute(0,3,1,2) # ouputs in b x c x h x w\n",
Expand Down Expand Up @@ -1304,6 +1312,10 @@
" return self.model(x)\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" torch.set_grad_enabled(True)\n",
" assert torch.is_grad_enabled()\n",
" assert all(p.requires_grad for p in self.parameters())\n",
" print(\"grad enabled\")\n",
" print(batch[\"pixels\"].shape, batch[\"labels\"].shape)\n",
" x, y = batch[\"pixels\"], batch[\"labels\"]\n",
" batch[\"pixels\"] = batch[\"pixels\"].to(model_clay.device)\n",
Expand All @@ -1317,7 +1329,7 @@
" latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n",
" prediction = self.model(latent)\n",
" y = torch.tensor(y.to(dtype=torch.long), requires_grad=False)\n",
" prediction = torch.tensor(prediction.to(dtype=torch.float32), requires_grad=True)\n",
" prediction = torch.tensor(prediction.float(), requires_grad=True)\n",
"        targets_one_hot = torch.nn.functional.one_hot(y, 2) # outputs in b x 1 x h x w x c\n",
" targets_one_hot = targets_one_hot.squeeze()\n",
" targets_one_hot = targets_one_hot.permute(0,3,1,2) # ouputs in b x c x h x w\n",
Expand All @@ -1329,6 +1341,10 @@
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" torch.set_grad_enabled(True)\n",
" assert torch.is_grad_enabled()\n",
" assert all(p.requires_grad for p in self.parameters())\n",
" print(\"grad enabled\")\n",
" x, y = batch[\"pixels\"], batch[\"labels\"]\n",
" batch[\"pixels\"] = batch[\"pixels\"].to(model_clay.device)\n",
" batch[\"timestep\"] = batch[\"timestep\"].to(model_clay.device)\n",
Expand All @@ -1342,7 +1358,7 @@
" latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n",
" prediction = self.model(latent)\n",
" y = torch.tensor(y.to(dtype=torch.long), requires_grad=False)\n",
" prediction = torch.tensor(prediction.to(dtype=torch.float32), requires_grad=True)\n",
" prediction = torch.tensor(prediction.float(), requires_grad=True)\n",
"        targets_one_hot = torch.nn.functional.one_hot(y, 2) # outputs in b x 1 x h x w x c\n",
" targets_one_hot = targets_one_hot.squeeze()\n",
" targets_one_hot = targets_one_hot.permute(0,3,1,2) # ouputs in b x c x h x w\n",
Expand Down

0 comments on commit 25ad052

Please sign in to comment.