[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 21, 2023
1 parent 11b1952 commit e65b2b0
Showing 3 changed files with 261 additions and 153 deletions.
106 changes: 54 additions & 52 deletions notebooks/231220-custom-mae-embeddings-finetune.ipynb
@@ -7,7 +7,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import sys\n",
 "from __future__ import annotations"
 ]
 },
@@ -17,9 +16,7 @@
 "id": "2d5517a8-7ff0-4340-90e3-3d37cf6ab11b",
 "metadata": {},
 "outputs": [],
-"source": [
-"from pathlib import Path"
-]
+"source": []
 },
 {
 "cell_type": "code",
@@ -38,21 +35,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from src.model_clay_eval import CLAYModule\n",
-"import src.datamodule\n",
-"#from src.datamodule import ClayDataset, ClayDataModule\n",
-"from src.datamodule_eval_local import ClayDataset, ClayDataModule\n",
-"import pandas as pd\n",
-"import random\n",
 "import matplotlib.pyplot as plt\n",
-"from torch.utils.data import DataLoader\n",
 "import numpy as np\n",
-"import einops\n",
-"from sklearn.decomposition import PCA\n",
-"from sklearn.metrics.pairwise import cosine_similarity\n",
-"import rasterio as rio\n",
-"from einops import rearrange, reduce\n",
-"import torch"
+"import torch\n",
+"from einops import rearrange\n",
+"\n",
+"# from src.datamodule import ClayDataset, ClayDataModule\n",
+"from src.datamodule_eval_local import ClayDataModule\n",
+"from src.model_clay_eval import CLAYModule"
 ]
 },
 {
@@ -62,8 +52,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = CLAYModule.load_from_checkpoint(\"../clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.)\n",
-"model.eval();"
+"model = CLAYModule.load_from_checkpoint(\n",
+"    \"../clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.0\n",
+")\n",
+"model.eval()"
 ]
 },
 {
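A note on the call above: mask_ratio controls the fraction of patch tokens the masked-autoencoder encoder hides, so mask_ratio=0.0 at evaluation time makes it embed every patch. With 6 band groups on a 16x16 patch grid that is 6 * 16 * 16 = 1536 patch tokens, which lines up with the 1536 + 2 tokens sliced in a later hunk (a quick sanity check; treating the two trailing tokens as the encoder's extra metadata tokens is an assumption):

    num_patch_tokens = 6 * 16 * 16  # band groups x 16x16 patch grid
    print(num_patch_tokens)         # 1536; the encoder output carries 1536 + 2 tokens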
@@ -524,11 +516,11 @@
 "from einops import rearrange\n",
 "\n",
 "embeddings = emb[0]\n",
-"embeddings = embeddings[:,:-2,:]\n",
+"embeddings = embeddings[:, :-2, :]\n",
 "latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n",
 "latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n",
 "latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n",
-"print(latent.shape)\n"
+"print(latent.shape)"
 ]
 },
 {
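To spell out what this hunk computes: emb[0] is the encoder's token sequence per image; dropping the last two tokens leaves 6 band groups x 256 patch tokens of width 768 (4608 / 6), and the three rearrange calls fold them into a (batch, 4608, 16, 16) feature map. A runnable sketch with synthetic tensors (shapes follow from the g=6, h=16, w=16 arguments in the diff):

    import torch
    from einops import rearrange

    B, G, L, D = 2, 6, 256, 768                # batch, band groups, patches, embed dim
    embeddings = torch.randn(B, G * L + 2, D)  # stand-in for emb[0]

    embeddings = embeddings[:, :-2, :]                            # (2, 1536, 768)
    latent = rearrange(embeddings, "b (g l) d -> b g l d", g=6)   # (2, 6, 256, 768)
    latent = rearrange(latent, "b g (h w) d -> b g h w d", h=16, w=16)
    latent = rearrange(latent, "b g h w d -> b (g d) h w")        # (2, 4608, 16, 16)
    print(latent.shape)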
@@ -590,7 +582,7 @@
 }
 ],
 "source": [
-"plt.imshow(batch[\"pixels\"][0].permute(1,2,0)[:,:,1].detach().numpy(), cmap=\"bwr\")"
+"plt.imshow(batch[\"pixels\"][0].permute(1, 2, 0)[:, :, 1].detach().numpy(), cmap=\"bwr\")"
 ]
 },
 {
@@ -611,9 +603,9 @@
 }
 ],
 "source": [
-"fig, axs = plt.subplots(1,10,figsize=(10,5))\n",
-"for i,ax in enumerate(axs.flatten()):\n",
-"    ax.imshow(latent[0][i+10].detach().numpy(), cmap=\"bwr\")"
+"fig, axs = plt.subplots(1, 10, figsize=(10, 5))\n",
+"for i, ax in enumerate(axs.flatten()):\n",
+"    ax.imshow(latent[0][i + 10].detach().numpy(), cmap=\"bwr\")"
 ]
 },
 {
@@ -624,8 +616,8 @@
 "outputs": [],
 "source": [
 "encoder = torch.nn.Sequential(\n",
-"        torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n",
-"    )\n",
+"    torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n",
+")\n",
 "\n",
 "decoder = torch.nn.Sequential(\n",
 "    torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n",
@@ -640,7 +632,7 @@
 "    torch.nn.ReLU(inplace=True),\n",
 "    torch.nn.Upsample(scale_factor=2),\n",
 "    torch.nn.ConvTranspose2d(8, 1, kernel_size=3, padding=1),\n",
-"    torch.nn.Upsample(scale_factor=2)\n",
+"    torch.nn.Upsample(scale_factor=2),\n",
 ")\n",
 "\n",
 "\n",
@@ -953,13 +945,14 @@
 }
 ],
 "source": [
-"from pytorch_lightning import LightningModule, Trainer\n",
 "from einops import rearrange\n",
+"from pytorch_lightning import LightningModule, Trainer\n",
 "\n",
+"\n",
 "class UNet(torch.nn.Module):\n",
 "    def __init__(self, in_channels, out_channels):\n",
-"        super().__init__() \n",
-"        \n",
+"        super().__init__()\n",
+"\n",
 "        self.decoder = torch.nn.Sequential(\n",
 "            torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n",
 "            torch.nn.Upsample(scale_factor=2),\n",
@@ -973,11 +966,10 @@
 "            torch.nn.ReLU(inplace=True),\n",
 "            torch.nn.Upsample(scale_factor=2),\n",
 "            torch.nn.ConvTranspose2d(8, 1, kernel_size=3, padding=1),\n",
-"            torch.nn.Upsample(scale_factor=2)\n",
+"            torch.nn.Upsample(scale_factor=2),\n",
 "        )\n",
 "\n",
-"\n",
-"    def forward(self,x):\n",
+"    def forward(self, x):\n",
 "        x = self.decoder(x)\n",
 "        return x\n",
 "\n",
@@ -987,7 +979,7 @@
 "        super().__init__()\n",
 "        self.model = model\n",
 "        self.datamodule = datamodule\n",
-"        \n",
+"\n",
 "    def forward(self, x):\n",
 "        return self.model(x)\n",
 "\n",
@@ -999,14 +991,16 @@
 "        batch[\"latlon\"] = batch[\"latlon\"].to(model_clay.device)\n",
 "        emb = model_clay.model.encoder(batch)\n",
 "        embeddings = emb[0]\n",
-"        embeddings = embeddings[:,:-2,:]\n",
+"        embeddings = embeddings[:, :-2, :]\n",
 "        latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n",
 "        latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n",
 "        latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n",
 "        prediction = self.model(latent)\n",
 "        print(\"Prediction shape:\", prediction.shape)\n",
 "        print(\"Label shape:\", y.shape)\n",
-"        loss = torch.nn.functional.binary_cross_entropy_with_logits(prediction.to(dtype=torch.float32), y)\n",
+"        loss = torch.nn.functional.binary_cross_entropy_with_logits(\n",
+"            prediction.to(dtype=torch.float32), y\n",
+"        )\n",
 "        loss = torch.tensor(loss, requires_grad=True)\n",
 "        self.log(\"train_loss\", loss)\n",
 "        print(\"train_loss\", loss)\n",
@@ -1019,7 +1013,7 @@
 "        batch[\"latlon\"] = batch[\"latlon\"].to(model_clay.device)\n",
 "        emb = model_clay.model.encoder(batch)\n",
 "        embeddings = emb[0]\n",
-"        embeddings = embeddings[:,:-2,:]\n",
+"        embeddings = embeddings[:, :-2, :]\n",
 "        latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n",
 "        latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n",
 "        latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n",
@@ -1037,20 +1031,24 @@
 "    def configure_optimizers(self):\n",
 "        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
 "        return optimizer\n",
-"        \n",
+"\n",
 "    def train_dataloader(self):\n",
 "        return self.datamodule.train_dataloader()\n",
 "\n",
 "    def val_dataloader(self):\n",
 "        return self.datamodule.val_dataloader()\n",
 "\n",
+"\n",
 "dm = ClayDataModule(data_dir=data_dir, batch_size=2)\n",
 "dm.setup()\n",
-"#val_dl = iter(dm.val_dataloader())\n",
+"# val_dl = iter(dm.val_dataloader())\n",
 "\n",
 "model_unet = UNet(13, 1)\n",
-"model_clay = CLAYModule.load_from_checkpoint(\"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.)\n",
-"model_clay.eval();\n",
+"model_clay = CLAYModule.load_from_checkpoint(\n",
+"    \"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\",\n",
+"    mask_ratio=0.0,\n",
+")\n",
+"model_clay.eval()\n",
 "segmentation_model = SegmentationModel(model_unet, dm)\n",
 "\n",
 "trainer = Trainer(max_epochs=3)\n",
@@ -1133,19 +1131,19 @@
 "    model.eval()\n",
 "    with torch.no_grad():\n",
 "        for batch in dataloader:\n",
-"            #print(batch)\n",
+"            # print(batch)\n",
 "            x_val, y_val = batch[\"pixels\"], batch[\"labels\"]\n",
-"            \n",
+"\n",
 "            emb = model_clay.model.encoder(batch)\n",
 "            embeddings = emb[0]\n",
-"            embeddings = embeddings[:,:-2,:]\n",
+"            embeddings = embeddings[:, :-2, :]\n",
 "            latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n",
 "            latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n",
 "            latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n",
-"            #y = y[0, :, :, :]\n",
+"            # y = y[0, :, :, :]\n",
 "            prediction = model(latent)\n",
-"            #y_pred = model(batch)\n",
-"            #_, prediction = torch.max(y_pred, dim=1)\n",
+"            # y_pred = model(batch)\n",
+"            # _, prediction = torch.max(y_pred, dim=1)\n",
 "\n",
 "            # Convert tensors to NumPy arrays for plotting\n",
 "            x_val_np = x_val.cpu().numpy()\n",
@@ -1170,7 +1168,7 @@
 "                )\n",
 "            ),\n",
 "            (1, 2, 0),\n",
-"        ) #.clip(0, 6000) / 6000\n",
+"        )  # .clip(0, 6000) / 6000\n",
 "        )  # x_val_np[i], (1, 2, 0)))  # Plot input images\n",
 "        axes[i, 1].imshow(\n",
 "            np.transpose(y_val_np[i], (1, 2, 0))\n",
@@ -1179,10 +1177,14 @@
 "\n",
 "    plt.show()\n",
 "\n",
-"model_clay = CLAYModule.load_from_checkpoint(\"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.)\n",
+"\n",
+"model_clay = CLAYModule.load_from_checkpoint(\n",
+"    \"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\",\n",
+"    mask_ratio=0.0,\n",
+")\n",
 "model_clay.eval()\n",
-"#val_dataloader = datamodule_floods.val_dataloader()\n",
-"plot_predictions(model_unet, model_clay, dm.val_dataloader())\n"
+"# val_dataloader = datamodule_floods.val_dataloader()\n",
+"plot_predictions(model_unet, model_clay, dm.val_dataloader())"
 ]
 },
 {